From b08dd8b75de7ec1a5053900c9a26eac9ef9a2cb6 Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Sat, 29 Mar 2025 14:08:42 -0400 Subject: [PATCH 1/5] Allow plot_probe not to plot on axes, but just return polycollections --- src/probeinterface/plotting.py | 108 ++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 50 deletions(-) diff --git a/src/probeinterface/plotting.py b/src/probeinterface/plotting.py index 40ee710d..eb6d09ac 100644 --- a/src/probeinterface/plotting.py +++ b/src/probeinterface/plotting.py @@ -28,6 +28,7 @@ def plot_probe( ylims: tuple | None = None, zlims: tuple | None = None, show_channel_on_click: bool = False, + add_to_axis: bool = True, ): """Plot a Probe object. Generates a 2D or 3D axis, depending on Probe.ndim @@ -64,6 +65,9 @@ def plot_probe( Limits for z dimension show_channel_on_click : bool, default: False If True, the channel information is shown upon click + add_to_axis : bool, default: True + If True, collections are added to the axis. If False, collections are + only returned without being added to the axis. Returns ------- @@ -79,14 +83,14 @@ def plot_probe( elif probe.ndim == 3: from mpl_toolkits.mplot3d.art3d import Poly3DCollection - if ax is None: + if ax is None and add_to_axis: if probe.ndim == 2: fig, ax = plt.subplots() ax.set_aspect("equal") else: fig = plt.figure() ax = fig.add_subplot(1, 1, 1, projection="3d") - else: + elif ax is not None: fig = ax.get_figure() _probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3) @@ -107,16 +111,18 @@ def plot_probe( vertices = probe.get_contact_vertices() if probe.ndim == 2: poly = PolyCollection(vertices, color=contacts_colors, **_contacts_kargs) - ax.add_collection(poly) + if add_to_axis and ax is not None: + ax.add_collection(poly) elif probe.ndim == 3: poly = Poly3DCollection(vertices, color=contacts_colors, **_contacts_kargs) - ax.add_collection3d(poly) + if add_to_axis and ax is not None: + ax.add_collection3d(poly) if contacts_values is not None: poly.set_array(contacts_values) poly.set_cmap(cmap) - if show_channel_on_click: + if show_channel_on_click and add_to_axis: assert probe.ndim == 2, "show_channel_on_click works only for ndim=2" def on_press(event): @@ -126,61 +132,63 @@ def on_press(event): fig.canvas.mpl_connect("button_release_event", on_release) # probe shape + poly_contour = None planar_contour = probe.probe_planar_contour if planar_contour is not None: if probe.ndim == 2: poly_contour = PolyCollection([planar_contour], **_probe_shape_kwargs) - ax.add_collection(poly_contour) + if add_to_axis and ax is not None: + ax.add_collection(poly_contour) elif probe.ndim == 3: poly_contour = Poly3DCollection([planar_contour], **_probe_shape_kwargs) - ax.add_collection3d(poly_contour) - else: - poly_contour = None - - if text_on_contact is not None: - text_on_contact = np.asarray(text_on_contact) - assert text_on_contact.size == probe.get_contact_count() + if add_to_axis and ax is not None: + ax.add_collection3d(poly_contour) + + if add_to_axis and ax is not None: + if text_on_contact is not None: + text_on_contact = np.asarray(text_on_contact) + assert text_on_contact.size == probe.get_contact_count() + + if with_contact_id or with_device_index or text_on_contact is not None: + if probe.ndim == 3: + raise NotImplementedError("Channel index is 2d only") + for i in range(n): + txt = [] + if with_contact_id and probe.contact_ids is not None: + contact_id = probe.contact_ids[i] + txt.append(f"id{contact_id}") + if with_device_index and probe.device_channel_indices is not None: + chan_ind = probe.device_channel_indices[i] + txt.append(f"dev{chan_ind}") + if text_on_contact is not None: + txt.append(f"{text_on_contact[i]}") + + txt = "\n".join(txt) + x, y = probe.contact_positions[i] + ax.text(x, y, txt, ha="center", va="center", clip_on=True) + + if xlims is None or ylims is None or (zlims is None and probe.ndim == 3): + xlims, ylims, zlims = get_auto_lims(probe) + + ax.set_xlim(*xlims) + ax.set_ylim(*ylims) + + if probe.si_units == "um": + unit_str = "($\\mu m$)" + else: + unit_str = f"({probe.si_units})" + ax.set_xlabel(f"x {unit_str}", fontsize=15) + ax.set_ylabel(f"y {unit_str}", fontsize=15) - if with_contact_id or with_device_index or text_on_contact is not None: if probe.ndim == 3: - raise NotImplementedError("Channel index is 2d only") - for i in range(n): - txt = [] - if with_contact_id and probe.contact_ids is not None: - contact_id = probe.contact_ids[i] - txt.append(f"id{contact_id}") - if with_device_index and probe.device_channel_indices is not None: - chan_ind = probe.device_channel_indices[i] - txt.append(f"dev{chan_ind}") - if text_on_contact is not None: - txt.append(f"{text_on_contact[i]}") - - txt = "\n".join(txt) - x, y = probe.contact_positions[i] - ax.text(x, y, txt, ha="center", va="center", clip_on=True) - - if xlims is None or ylims is None or (zlims is None and probe.ndim == 3): - xlims, ylims, zlims = get_auto_lims(probe) - - ax.set_xlim(*xlims) - ax.set_ylim(*ylims) - - if probe.si_units == "um": - unit_str = "($\\mu m$)" - else: - unit_str = f"({probe.si_units})" - ax.set_xlabel(f"x {unit_str}", fontsize=15) - ax.set_ylabel(f"y {unit_str}", fontsize=15) + ax.set_zlim(zlims) + ax.set_zlabel("z") - if probe.ndim == 3: - ax.set_zlim(zlims) - ax.set_zlabel("z") - - if probe.ndim == 2: - ax.set_aspect("equal") + if probe.ndim == 2: + ax.set_aspect("equal") - if title: - ax.set_title(probe.get_title()) + if title: + ax.set_title(probe.get_title()) return poly, poly_contour From 95c135669d2d4396725678fac10ea4fcaf95bf27 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 29 Mar 2025 18:20:06 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/probeinterface/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probeinterface/plotting.py b/src/probeinterface/plotting.py index eb6d09ac..0d8dae67 100644 --- a/src/probeinterface/plotting.py +++ b/src/probeinterface/plotting.py @@ -66,7 +66,7 @@ def plot_probe( show_channel_on_click : bool, default: False If True, the channel information is shown upon click add_to_axis : bool, default: True - If True, collections are added to the axis. If False, collections are + If True, collections are added to the axis. If False, collections are only returned without being added to the axis. Returns From 4e787cb4814947975a35071578ef8d188273c721 Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Tue, 1 Apr 2025 06:01:09 -0400 Subject: [PATCH 3/5] Refactor polycollection creation to separate function --- src/probeinterface/plotting.py | 224 +++++++++++++++++++-------------- 1 file changed, 132 insertions(+), 92 deletions(-) diff --git a/src/probeinterface/plotting.py b/src/probeinterface/plotting.py index 0d8dae67..ce048e4e 100644 --- a/src/probeinterface/plotting.py +++ b/src/probeinterface/plotting.py @@ -12,6 +12,77 @@ from .utils import get_auto_lims +def create_probe_collections( + probe, + contacts_colors: list | None = None, + contacts_values: np.ndarray | None = None, + cmap: str = "viridis", + contacts_kargs: dict = {}, + probe_shape_kwargs: dict = {}, +): + """Create PolyCollection objects for a Probe. + + Parameters + ---------- + probe : Probe + The probe object + contacts_colors : matplotlib color | None, default: None + The color of the contacts + contacts_values : np.ndarray | None, default: None + Values to color the contacts with + cmap : str, default: "viridis" + A colormap color + contacts_kargs : dict, default: {} + Dict with kwargs for contacts (e.g. alpha, edgecolor, lw) + probe_shape_kwargs : dict, default: {} + Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw) + + Returns + ------- + poly : PolyCollection + The polygon collection for contacts + poly_contour : PolyCollection | None + The polygon collection for the probe shape + """ + if probe.ndim == 2: + from matplotlib.collections import PolyCollection + Collection = PolyCollection + elif probe.ndim == 3: + from mpl_toolkits.mplot3d.art3d import Poly3DCollection + Collection = Poly3DCollection + else: + raise ValueError(f"Unexpected probe.ndim: {probe.ndim}") + + _probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3) + _probe_shape_kwargs.update(probe_shape_kwargs) + + _contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5) + _contacts_kargs.update(contacts_kargs) + + n = probe.get_contact_count() + + if contacts_colors is None and contacts_values is None: + contacts_colors = ["orange"] * n + elif contacts_colors is not None: + contacts_colors = contacts_colors + elif contacts_values is not None: + contacts_colors = None + + vertices = probe.get_contact_vertices() + poly = Collection(vertices, color=contacts_colors, **_contacts_kargs) + + if contacts_values is not None: + poly.set_array(contacts_values) + poly.set_cmap(cmap) + + # probe shape + poly_contour = None + planar_contour = probe.probe_planar_contour + if planar_contour is not None: + poly_contour = Collection([planar_contour], **_probe_shape_kwargs) + + return poly, poly_contour + def plot_probe( probe, ax=None, @@ -28,7 +99,6 @@ def plot_probe( ylims: tuple | None = None, zlims: tuple | None = None, show_channel_on_click: bool = False, - add_to_axis: bool = True, ): """Plot a Probe object. Generates a 2D or 3D axis, depending on Probe.ndim @@ -65,9 +135,6 @@ def plot_probe( Limits for z dimension show_channel_on_click : bool, default: False If True, the channel information is shown upon click - add_to_axis : bool, default: True - If True, collections are added to the axis. If False, collections are - only returned without being added to the axis. Returns ------- @@ -78,51 +145,37 @@ def plot_probe( """ import matplotlib.pyplot as plt - if probe.ndim == 2: - from matplotlib.collections import PolyCollection - elif probe.ndim == 3: - from mpl_toolkits.mplot3d.art3d import Poly3DCollection - - if ax is None and add_to_axis: + if ax is None: if probe.ndim == 2: fig, ax = plt.subplots() ax.set_aspect("equal") else: fig = plt.figure() ax = fig.add_subplot(1, 1, 1, projection="3d") - elif ax is not None: + else: fig = ax.get_figure() - _probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3) - _probe_shape_kwargs.update(probe_shape_kwargs) - - _contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5) - _contacts_kargs.update(contacts_kargs) - - n = probe.get_contact_count() - - if contacts_colors is None and contacts_values is None: - contacts_colors = ["orange"] * n - elif contacts_colors is not None: - contacts_colors = contacts_colors - elif contacts_values is not None: - contacts_colors = None + # Create collections (contacts, probe shape) + poly, poly_contour = create_probe_collections( + probe, + contacts_colors=contacts_colors, + contacts_values=contacts_values, + cmap=cmap, + contacts_kargs=contacts_kargs, + probe_shape_kwargs=probe_shape_kwargs, + ) - vertices = probe.get_contact_vertices() + # Add collections to the axis if probe.ndim == 2: - poly = PolyCollection(vertices, color=contacts_colors, **_contacts_kargs) - if add_to_axis and ax is not None: - ax.add_collection(poly) + ax.add_collection(poly) + if poly_contour is not None: + ax.add_collection(poly_contour) elif probe.ndim == 3: - poly = Poly3DCollection(vertices, color=contacts_colors, **_contacts_kargs) - if add_to_axis and ax is not None: - ax.add_collection3d(poly) - - if contacts_values is not None: - poly.set_array(contacts_values) - poly.set_cmap(cmap) - - if show_channel_on_click and add_to_axis: + ax.add_collection3d(poly) + if poly_contour is not None: + ax.add_collection3d(poly_contour) + + if show_channel_on_click: assert probe.ndim == 2, "show_channel_on_click works only for ndim=2" def on_press(event): @@ -131,64 +184,51 @@ def on_press(event): fig.canvas.mpl_connect("button_press_event", on_press) fig.canvas.mpl_connect("button_release_event", on_release) - # probe shape - poly_contour = None - planar_contour = probe.probe_planar_contour - if planar_contour is not None: - if probe.ndim == 2: - poly_contour = PolyCollection([planar_contour], **_probe_shape_kwargs) - if add_to_axis and ax is not None: - ax.add_collection(poly_contour) - elif probe.ndim == 3: - poly_contour = Poly3DCollection([planar_contour], **_probe_shape_kwargs) - if add_to_axis and ax is not None: - ax.add_collection3d(poly_contour) - - if add_to_axis and ax is not None: - if text_on_contact is not None: - text_on_contact = np.asarray(text_on_contact) - assert text_on_contact.size == probe.get_contact_count() - - if with_contact_id or with_device_index or text_on_contact is not None: - if probe.ndim == 3: - raise NotImplementedError("Channel index is 2d only") - for i in range(n): - txt = [] - if with_contact_id and probe.contact_ids is not None: - contact_id = probe.contact_ids[i] - txt.append(f"id{contact_id}") - if with_device_index and probe.device_channel_indices is not None: - chan_ind = probe.device_channel_indices[i] - txt.append(f"dev{chan_ind}") - if text_on_contact is not None: - txt.append(f"{text_on_contact[i]}") - - txt = "\n".join(txt) - x, y = probe.contact_positions[i] - ax.text(x, y, txt, ha="center", va="center", clip_on=True) - - if xlims is None or ylims is None or (zlims is None and probe.ndim == 3): - xlims, ylims, zlims = get_auto_lims(probe) - - ax.set_xlim(*xlims) - ax.set_ylim(*ylims) - - if probe.si_units == "um": - unit_str = "($\\mu m$)" - else: - unit_str = f"({probe.si_units})" - ax.set_xlabel(f"x {unit_str}", fontsize=15) - ax.set_ylabel(f"y {unit_str}", fontsize=15) + if text_on_contact is not None: + text_on_contact = np.asarray(text_on_contact) + assert text_on_contact.size == probe.get_contact_count() + n = probe.get_contact_count() + if with_contact_id or with_device_index or text_on_contact is not None: if probe.ndim == 3: - ax.set_zlim(zlims) - ax.set_zlabel("z") + raise NotImplementedError("Channel index is 2d only") + for i in range(n): + txt = [] + if with_contact_id and probe.contact_ids is not None: + contact_id = probe.contact_ids[i] + txt.append(f"id{contact_id}") + if with_device_index and probe.device_channel_indices is not None: + chan_ind = probe.device_channel_indices[i] + txt.append(f"dev{chan_ind}") + if text_on_contact is not None: + txt.append(f"{text_on_contact[i]}") + + txt = "\n".join(txt) + x, y = probe.contact_positions[i] + ax.text(x, y, txt, ha="center", va="center", clip_on=True) + + if xlims is None or ylims is None or (zlims is None and probe.ndim == 3): + xlims, ylims, zlims = get_auto_lims(probe) + + ax.set_xlim(*xlims) + ax.set_ylim(*ylims) + + if probe.si_units == "um": + unit_str = "($\\mu m$)" + else: + unit_str = f"({probe.si_units})" + ax.set_xlabel(f"x {unit_str}", fontsize=15) + ax.set_ylabel(f"y {unit_str}", fontsize=15) - if probe.ndim == 2: - ax.set_aspect("equal") + if probe.ndim == 3: + ax.set_zlim(zlims) + ax.set_zlabel("z") + + if probe.ndim == 2: + ax.set_aspect("equal") - if title: - ax.set_title(probe.get_title()) + if title: + ax.set_title(probe.get_title()) return poly, poly_contour From bfec1d1bae11c3cd9a09238f10e43465bce31b29 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:01:19 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/probeinterface/plotting.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/probeinterface/plotting.py b/src/probeinterface/plotting.py index ce048e4e..482c49b7 100644 --- a/src/probeinterface/plotting.py +++ b/src/probeinterface/plotting.py @@ -21,7 +21,7 @@ def create_probe_collections( probe_shape_kwargs: dict = {}, ): """Create PolyCollection objects for a Probe. - + Parameters ---------- probe : Probe @@ -36,7 +36,7 @@ def create_probe_collections( Dict with kwargs for contacts (e.g. alpha, edgecolor, lw) probe_shape_kwargs : dict, default: {} Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw) - + Returns ------- poly : PolyCollection @@ -46,9 +46,11 @@ def create_probe_collections( """ if probe.ndim == 2: from matplotlib.collections import PolyCollection + Collection = PolyCollection elif probe.ndim == 3: from mpl_toolkits.mplot3d.art3d import Poly3DCollection + Collection = Poly3DCollection else: raise ValueError(f"Unexpected probe.ndim: {probe.ndim}") @@ -70,7 +72,7 @@ def create_probe_collections( vertices = probe.get_contact_vertices() poly = Collection(vertices, color=contacts_colors, **_contacts_kargs) - + if contacts_values is not None: poly.set_array(contacts_values) poly.set_cmap(cmap) @@ -80,9 +82,10 @@ def create_probe_collections( planar_contour = probe.probe_planar_contour if planar_contour is not None: poly_contour = Collection([planar_contour], **_probe_shape_kwargs) - + return poly, poly_contour + def plot_probe( probe, ax=None, @@ -174,7 +177,7 @@ def plot_probe( ax.add_collection3d(poly) if poly_contour is not None: ax.add_collection3d(poly_contour) - + if show_channel_on_click: assert probe.ndim == 2, "show_channel_on_click works only for ndim=2" From 4604a7e76339c37ca3a305d5c71c161992944126 Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Tue, 1 Apr 2025 06:04:33 -0400 Subject: [PATCH 5/5] rename function for clarity --- src/probeinterface/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/probeinterface/plotting.py b/src/probeinterface/plotting.py index 482c49b7..da84dc1b 100644 --- a/src/probeinterface/plotting.py +++ b/src/probeinterface/plotting.py @@ -12,7 +12,7 @@ from .utils import get_auto_lims -def create_probe_collections( +def create_probe_polygons( probe, contacts_colors: list | None = None, contacts_values: np.ndarray | None = None, @@ -159,7 +159,7 @@ def plot_probe( fig = ax.get_figure() # Create collections (contacts, probe shape) - poly, poly_contour = create_probe_collections( + poly, poly_contour = create_probe_polygons( probe, contacts_colors=contacts_colors, contacts_values=contacts_values,