Skip to content

Allow plot_probe not to plot on axes, but just return polycollections #334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 89 additions & 38 deletions src/probeinterface/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,80 @@
from .utils import get_auto_lims


def create_probe_polygons(
probe,
contacts_colors: list | None = None,
contacts_values: np.ndarray | None = None,
cmap: str = "viridis",
contacts_kargs: dict = {},
probe_shape_kwargs: dict = {},
):
"""Create PolyCollection objects for a Probe.

Parameters
----------
probe : Probe
The probe object
contacts_colors : matplotlib color | None, default: None
The color of the contacts
contacts_values : np.ndarray | None, default: None
Values to color the contacts with
cmap : str, default: "viridis"
A colormap color
contacts_kargs : dict, default: {}
Dict with kwargs for contacts (e.g. alpha, edgecolor, lw)
probe_shape_kwargs : dict, default: {}
Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw)

Returns
-------
poly : PolyCollection
The polygon collection for contacts
poly_contour : PolyCollection | None
The polygon collection for the probe shape
"""
if probe.ndim == 2:
from matplotlib.collections import PolyCollection

Collection = PolyCollection
elif probe.ndim == 3:
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

Collection = Poly3DCollection
else:
raise ValueError(f"Unexpected probe.ndim: {probe.ndim}")

Check warning on line 56 in src/probeinterface/plotting.py

View check run for this annotation

Codecov / codecov/patch

src/probeinterface/plotting.py#L56

Added line #L56 was not covered by tests

_probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3)
_probe_shape_kwargs.update(probe_shape_kwargs)

_contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5)
_contacts_kargs.update(contacts_kargs)

n = probe.get_contact_count()

if contacts_colors is None and contacts_values is None:
contacts_colors = ["orange"] * n
elif contacts_colors is not None:
contacts_colors = contacts_colors
elif contacts_values is not None:
contacts_colors = None

Check warning on line 71 in src/probeinterface/plotting.py

View check run for this annotation

Codecov / codecov/patch

src/probeinterface/plotting.py#L70-L71

Added lines #L70 - L71 were not covered by tests

vertices = probe.get_contact_vertices()
poly = Collection(vertices, color=contacts_colors, **_contacts_kargs)

if contacts_values is not None:
poly.set_array(contacts_values)
poly.set_cmap(cmap)

Check warning on line 78 in src/probeinterface/plotting.py

View check run for this annotation

Codecov / codecov/patch

src/probeinterface/plotting.py#L77-L78

Added lines #L77 - L78 were not covered by tests

# probe shape
poly_contour = None
planar_contour = probe.probe_planar_contour
if planar_contour is not None:
poly_contour = Collection([planar_contour], **_probe_shape_kwargs)

return poly, poly_contour


def plot_probe(
probe,
ax=None,
Expand Down Expand Up @@ -74,11 +148,6 @@
"""
import matplotlib.pyplot as plt

if probe.ndim == 2:
from matplotlib.collections import PolyCollection
elif probe.ndim == 3:
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

if ax is None:
if probe.ndim == 2:
fig, ax = plt.subplots()
Expand All @@ -89,32 +158,25 @@
else:
fig = ax.get_figure()

_probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3)
_probe_shape_kwargs.update(probe_shape_kwargs)

_contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5)
_contacts_kargs.update(contacts_kargs)

n = probe.get_contact_count()

if contacts_colors is None and contacts_values is None:
contacts_colors = ["orange"] * n
elif contacts_colors is not None:
contacts_colors = contacts_colors
elif contacts_values is not None:
contacts_colors = None
# Create collections (contacts, probe shape)
poly, poly_contour = create_probe_polygons(
probe,
contacts_colors=contacts_colors,
contacts_values=contacts_values,
cmap=cmap,
contacts_kargs=contacts_kargs,
probe_shape_kwargs=probe_shape_kwargs,
)

vertices = probe.get_contact_vertices()
# Add collections to the axis
if probe.ndim == 2:
poly = PolyCollection(vertices, color=contacts_colors, **_contacts_kargs)
ax.add_collection(poly)
if poly_contour is not None:
ax.add_collection(poly_contour)
elif probe.ndim == 3:
poly = Poly3DCollection(vertices, color=contacts_colors, **_contacts_kargs)
ax.add_collection3d(poly)

if contacts_values is not None:
poly.set_array(contacts_values)
poly.set_cmap(cmap)
if poly_contour is not None:
ax.add_collection3d(poly_contour)

if show_channel_on_click:
assert probe.ndim == 2, "show_channel_on_click works only for ndim=2"
Expand All @@ -125,22 +187,11 @@
fig.canvas.mpl_connect("button_press_event", on_press)
fig.canvas.mpl_connect("button_release_event", on_release)

# probe shape
planar_contour = probe.probe_planar_contour
if planar_contour is not None:
if probe.ndim == 2:
poly_contour = PolyCollection([planar_contour], **_probe_shape_kwargs)
ax.add_collection(poly_contour)
elif probe.ndim == 3:
poly_contour = Poly3DCollection([planar_contour], **_probe_shape_kwargs)
ax.add_collection3d(poly_contour)
else:
poly_contour = None

if text_on_contact is not None:
text_on_contact = np.asarray(text_on_contact)
assert text_on_contact.size == probe.get_contact_count()

n = probe.get_contact_count()
if with_contact_id or with_device_index or text_on_contact is not None:
if probe.ndim == 3:
raise NotImplementedError("Channel index is 2d only")
Expand Down