diff --git a/src/probeinterface/plotting.py b/src/probeinterface/plotting.py index 40ee710d..da84dc1b 100644 --- a/src/probeinterface/plotting.py +++ b/src/probeinterface/plotting.py @@ -12,6 +12,80 @@ from .utils import get_auto_lims +def create_probe_polygons( + probe, + contacts_colors: list | None = None, + contacts_values: np.ndarray | None = None, + cmap: str = "viridis", + contacts_kargs: dict = {}, + probe_shape_kwargs: dict = {}, +): + """Create PolyCollection objects for a Probe. + + Parameters + ---------- + probe : Probe + The probe object + contacts_colors : matplotlib color | None, default: None + The color of the contacts + contacts_values : np.ndarray | None, default: None + Values to color the contacts with + cmap : str, default: "viridis" + A colormap color + contacts_kargs : dict, default: {} + Dict with kwargs for contacts (e.g. alpha, edgecolor, lw) + probe_shape_kwargs : dict, default: {} + Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw) + + Returns + ------- + poly : PolyCollection + The polygon collection for contacts + poly_contour : PolyCollection | None + The polygon collection for the probe shape + """ + if probe.ndim == 2: + from matplotlib.collections import PolyCollection + + Collection = PolyCollection + elif probe.ndim == 3: + from mpl_toolkits.mplot3d.art3d import Poly3DCollection + + Collection = Poly3DCollection + else: + raise ValueError(f"Unexpected probe.ndim: {probe.ndim}") + + _probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3) + _probe_shape_kwargs.update(probe_shape_kwargs) + + _contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5) + _contacts_kargs.update(contacts_kargs) + + n = probe.get_contact_count() + + if contacts_colors is None and contacts_values is None: + contacts_colors = ["orange"] * n + elif contacts_colors is not None: + contacts_colors = contacts_colors + elif contacts_values is not None: + contacts_colors = None + + vertices = probe.get_contact_vertices() + poly = Collection(vertices, color=contacts_colors, **_contacts_kargs) + + if contacts_values is not None: + poly.set_array(contacts_values) + poly.set_cmap(cmap) + + # probe shape + poly_contour = None + planar_contour = probe.probe_planar_contour + if planar_contour is not None: + poly_contour = Collection([planar_contour], **_probe_shape_kwargs) + + return poly, poly_contour + + def plot_probe( probe, ax=None, @@ -74,11 +148,6 @@ def plot_probe( """ import matplotlib.pyplot as plt - if probe.ndim == 2: - from matplotlib.collections import PolyCollection - elif probe.ndim == 3: - from mpl_toolkits.mplot3d.art3d import Poly3DCollection - if ax is None: if probe.ndim == 2: fig, ax = plt.subplots() @@ -89,32 +158,25 @@ def plot_probe( else: fig = ax.get_figure() - _probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3) - _probe_shape_kwargs.update(probe_shape_kwargs) - - _contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5) - _contacts_kargs.update(contacts_kargs) - - n = probe.get_contact_count() - - if contacts_colors is None and contacts_values is None: - contacts_colors = ["orange"] * n - elif contacts_colors is not None: - contacts_colors = contacts_colors - elif contacts_values is not None: - contacts_colors = None + # Create collections (contacts, probe shape) + poly, poly_contour = create_probe_polygons( + probe, + contacts_colors=contacts_colors, + contacts_values=contacts_values, + cmap=cmap, + contacts_kargs=contacts_kargs, + probe_shape_kwargs=probe_shape_kwargs, + ) - vertices = probe.get_contact_vertices() + # Add collections to the axis if probe.ndim == 2: - poly = PolyCollection(vertices, color=contacts_colors, **_contacts_kargs) ax.add_collection(poly) + if poly_contour is not None: + ax.add_collection(poly_contour) elif probe.ndim == 3: - poly = Poly3DCollection(vertices, color=contacts_colors, **_contacts_kargs) ax.add_collection3d(poly) - - if contacts_values is not None: - poly.set_array(contacts_values) - poly.set_cmap(cmap) + if poly_contour is not None: + ax.add_collection3d(poly_contour) if show_channel_on_click: assert probe.ndim == 2, "show_channel_on_click works only for ndim=2" @@ -125,22 +187,11 @@ def on_press(event): fig.canvas.mpl_connect("button_press_event", on_press) fig.canvas.mpl_connect("button_release_event", on_release) - # probe shape - planar_contour = probe.probe_planar_contour - if planar_contour is not None: - if probe.ndim == 2: - poly_contour = PolyCollection([planar_contour], **_probe_shape_kwargs) - ax.add_collection(poly_contour) - elif probe.ndim == 3: - poly_contour = Poly3DCollection([planar_contour], **_probe_shape_kwargs) - ax.add_collection3d(poly_contour) - else: - poly_contour = None - if text_on_contact is not None: text_on_contact = np.asarray(text_on_contact) assert text_on_contact.size == probe.get_contact_count() + n = probe.get_contact_count() if with_contact_id or with_device_index or text_on_contact is not None: if probe.ndim == 3: raise NotImplementedError("Channel index is 2d only")