diff --git a/doc/changes/devel/12071.newfeature.rst b/doc/changes/devel/12071.newfeature.rst new file mode 100644 index 00000000000..4e7995e3beb --- /dev/null +++ b/doc/changes/devel/12071.newfeature.rst @@ -0,0 +1 @@ +Add new ``select`` parameter to :func:`mne.viz.plot_evoked_topo` and :meth:`mne.Evoked.plot_topo` to toggle lasso selection of sensors, by `Marijn van Vliet`_. diff --git a/mne/epochs.py b/mne/epochs.py index 679643ab969..ee8921d3990 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1353,6 +1353,7 @@ def plot_topo_image( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): return plot_topo_image_epochs( @@ -1371,6 +1372,7 @@ def plot_topo_image( fig_facecolor=fig_facecolor, fig_background=fig_background, font_color=font_color, + select=select, show=show, ) diff --git a/mne/evoked.py b/mne/evoked.py index c04f83531e3..7bd2355e4ee 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -613,6 +613,7 @@ def plot_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """. @@ -639,6 +640,7 @@ def plot_topo( background_color=background_color, noise_cov=noise_cov, exclude=exclude, + select=select, show=show, ) diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index b63d2a395e2..f492c4b7fde 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -500,11 +500,11 @@ def _create_ch_location_fig(self, pick): show=False, ) # highlight desired channel & disable interactivity - inds = np.isin(fig.lasso.ch_names, [ch_name]) + fig.lasso.selection_inds = np.isin(fig.lasso.names, [ch_name]) fig.lasso.disconnect() - fig.lasso.alpha_other = 0.3 + fig.lasso.alpha_nonselected = 0.3 fig.lasso.linewidth_selected = 3 - fig.lasso.style_sensors(inds) + fig.lasso.style_objects() return fig diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 2e552bd4012..3987b641dff 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -1536,7 +1536,7 @@ def _update_selection(self): def _update_highlighted_sensors(self): """Update the sensor plot to show what is selected.""" inds = np.isin( - self.mne.fig_selection.lasso.ch_names, self.mne.ch_names[self.mne.picks] + self.mne.fig_selection.lasso.names, self.mne.ch_names[self.mne.picks] ).nonzero()[0] self.mne.fig_selection.lasso.select_many(inds) diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index b047de4ea32..96ee0684e6e 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1153,6 +1153,7 @@ def plot_evoked_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """Plot 2D topography of evoked responses. @@ -1218,6 +1219,15 @@ def plot_evoked_topo( exclude : list of str | ``'bads'`` Channels names to exclude from the plot. If ``'bads'``, the bad channels are excluded. By default, exclude is set to ``'bads'``. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 + exclude : list of str | ``'bads'`` + Channels names to exclude from the plot. If ``'bads'``, the + bad channels are excluded. By default, exclude is set to ``'bads'``. show : bool Show figure if True. @@ -1274,10 +1284,11 @@ def plot_evoked_topo( font_color=font_color, merge_channels=merge_grads, legend=legend, + noise_cov=noise_cov, axes=axes, exclude=exclude, + select=select, show=show, - noise_cov=noise_cov, ) diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index 89e0a7c543d..caa09ae4d07 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -1088,36 +1088,25 @@ def test_plot_sensors(raw): pytest.raises(TypeError, plot_sensors, raw) # needs to be info pytest.raises(ValueError, plot_sensors, raw.info, kind="sasaasd") plt.close("all") + + # Test lasso selection. fig, sels = raw.plot_sensors("select", show_names=True) ax = fig.axes[0] - - # Click with no sensors - _fake_click(fig, ax, (0.0, 0.0), xform="data") - _fake_click(fig, ax, (0, 0.0), xform="data", kind="release") - assert fig.lasso.selection == [] - - # Lasso with 1 sensor (upper left) - _fake_click(fig, ax, (0, 1), xform="ax") - fig.canvas.draw() - assert fig.lasso.selection == [] - _fake_click(fig, ax, (0.65, 1), xform="ax", kind="motion") - _fake_click(fig, ax, (0.65, 0.7), xform="ax", kind="motion") - _fake_keypress(fig, "control") - _fake_click(fig, ax, (0, 0.7), xform="ax", kind="release", key="control") + # Lasso a single sensor. + _fake_click(fig, ax, (-0.13, 0.13), xform="data") + _fake_click(fig, ax, (-0.11, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.11, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="release") assert fig.lasso.selection == ["MEG 0121"] - # check that point appearance changes - fc = fig.lasso.collection.get_facecolors() - ec = fig.lasso.collection.get_edgecolors() - assert (fc[:, -1] == [0.5, 1.0, 0.5]).all() - assert (ec[:, -1] == [0.25, 1.0, 0.25]).all() - - _fake_click(fig, ax, (0.7, 1), xform="ax", kind="motion", key="control") - xy = ax.collections[0].get_offsets() - _fake_click(fig, ax, xy[2], xform="data", key="control") # single sel + # Add another sensor with a single click. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") assert fig.lasso.selection == ["MEG 0121", "MEG 0131"] - _fake_click(fig, ax, xy[2], xform="data", key="control") # deselect - assert fig.lasso.selection == ["MEG 0121"] plt.close("all") raw.info["dev_head_t"] = None # like empty room diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 85b4b43dcf8..48d031739b9 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -23,7 +23,7 @@ ) from mne.viz.evoked import _line_plot_onselect from mne.viz.topo import _imshow_tfr, _plot_update_evoked_topo_proj, iter_topography -from mne.viz.utils import _fake_click +from mne.viz.utils import _fake_click, _fake_keypress base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" evoked_fname = base_dir / "test-ave.fif" @@ -231,6 +231,16 @@ def test_plot_topo(): break plt.close("all") + # Test plot_topo with selection of channels enabled. + fig = evoked.plot_topo(select=True) + ax = fig.axes[0] + _fake_click(fig, ax, (0.05, 0.62), xform="data") + _fake_click(fig, ax, (0.2, 0.62), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0113", "MEG 0112", "MEG 0111"] + def test_plot_topo_nirs(fnirs_evoked): """Test plotting of ERP topography for nirs data.""" @@ -296,6 +306,30 @@ def test_plot_topo_image_epochs(): assert qm_cmap[0] is cmap +def test_plot_topo_select(): + """Test selecting sensors in an ERP topography plot.""" + # Show topography + evoked = _get_epochs().average() + fig = plot_evoked_topo(evoked, select=True) + ax = fig.axes[0] + + # Lasso select 3 out of the 6 sensors. + _fake_click(fig, ax, (0.05, 0.5), xform="data") + _fake_click(fig, ax, (0.2, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0132", "MEG 0133", "MEG 0131"] + + # Add another sensor with a single click. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (0.11, 0.65), xform="data") + _fake_click(fig, ax, (0.21, 0.65), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert fig.lasso.selection == ["MEG 0111", "MEG 0132", "MEG 0133", "MEG 0131"] + + def test_plot_tfr_topo(): """Test plotting of TFR data.""" epochs = _get_epochs() diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 59e2976e464..55dc0f1e65c 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -16,6 +16,7 @@ from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap from mne.viz.ui_events import ColormapRange, link, subscribe from mne.viz.utils import ( + SelectFromCollection, _compute_scalings, _fake_click, _fake_keypress, @@ -274,3 +275,71 @@ def callback(event): cmap_new1 = fig.axes[0].CB.mappable.get_cmap().name cmap_new2 = fig2.axes[0].CB.mappable.get_cmap().name assert cmap_new1 == cmap_new2 == cmap_want != cmap_old + + +def test_select_from_collection(): + """Test the lasso selector for matplotlib figures.""" + fig, ax = plt.subplots() + collection = ax.scatter([1, 2, 2, 1], [1, 1, 0, 0], color="black", edgecolor="red") + ax.set_xlim(-1, 4) + ax.set_ylim(-1, 2) + lasso = SelectFromCollection(ax, collection, names=["A", "B", "C", "D"]) + assert lasso.selection == [] + + # Make a selection with no patches inside of it. + _fake_click(fig, ax, (0, 0), xform="data") + _fake_click(fig, ax, (0.5, 0), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="release") + assert lasso.selection == [] + + # Doing a single click on a patch should not select it. + _fake_click(fig, ax, (1, 1), xform="data") + assert lasso.selection == [] + + # Make a selection with two patches in it. + _fake_click(fig, ax, (0, 0.5), xform="data") + _fake_click(fig, ax, (3, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (3, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="release") + assert lasso.selection == ["A", "B"] + + # Use Control key to lasso an additional patch. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (0.5, -0.5), xform="data") + _fake_click(fig, ax, (1.5, -0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert lasso.selection == ["A", "B", "D"] + + # Use CTRL+SHIFT to remove a patch. + _fake_keypress(fig, "ctrl+shift") + _fake_click(fig, ax, (0.5, 0.5), xform="data") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="release") + _fake_keypress(fig, "ctrl+shift", kind="release") + assert lasso.selection == ["B", "D"] + + # Check that the two selected patches have a different appearance. + fc = lasso.collection.get_facecolors() + ec = lasso.collection.get_edgecolors() + assert (fc[:, -1] == [0.5, 1.0, 0.5, 1.0]).all() + assert (ec[:, -1] == [0.25, 1.0, 0.25, 1.0]).all() + + # Test adding and removing single channels. + lasso.select_one(2) # should not do anything without modifier keys + assert lasso.selection == ["B", "D"] + _fake_keypress(fig, "control") + lasso.select_one(2) # add to selection + _fake_keypress(fig, "control", kind="release") + assert lasso.selection == ["B", "C", "D"] + _fake_keypress(fig, "ctrl+shift") + lasso.select_one(1) # remove from selection + assert lasso.selection == ["C", "D"] + _fake_keypress(fig, "ctrl+shift", kind="release") diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 3364a455aed..5c43d4de48e 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -13,8 +13,10 @@ from .._fiff.pick import _picks_to_idx, channel_type, pick_types from ..defaults import _handle_default from ..utils import Bunch, _check_option, _clean_names, _is_numeric, _to_rgb, fill_doc +from .ui_events import ChannelsSelect, publish, subscribe from .utils import ( DraggableColorbar, + SelectFromCollection, _check_cov, _check_delayed_ssp, _draw_proj_checkbox, @@ -37,6 +39,7 @@ def iter_topography( axis_spinecolor="k", layout_scale=None, legend=False, + select=False, ): """Create iterator over channel positions. @@ -72,6 +75,12 @@ def iter_topography( If True, an additional axis is created in the bottom right corner that can be used to, e.g., construct a legend. The index of this axis will be -1. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 Returns ------- @@ -93,6 +102,7 @@ def iter_topography( axis_spinecolor, layout_scale, legend=legend, + select=select, ) @@ -128,6 +138,7 @@ def _iter_topography( img=False, axes=None, legend=False, + select=False, ): """Iterate over topography. @@ -193,8 +204,11 @@ def format_coord_multiaxis(x, y, ch_name=None): under_ax.set(xlim=[0, 1], ylim=[0, 1]) axs = list() + + shown_ch_names = [] for idx, name in iter_ch: ch_idx = ch_names.index(name) + shown_ch_names.append(name) if not unified: # old, slow way ax = plt.axes(pos[idx]) ax.patch.set_facecolor(axis_facecolor) @@ -226,24 +240,48 @@ def format_coord_multiaxis(x, y, ch_name=None): if unified: under_ax._mne_axs = axs # Create a PolyCollection for the axis backgrounds + sel_pos = pos[[i[0] for i in iter_ch]] verts = np.transpose( [ - pos[:, :2], - pos[:, :2] + pos[:, 2:] * [1, 0], - pos[:, :2] + pos[:, 2:], - pos[:, :2] + pos[:, 2:] * [0, 1], + sel_pos[:, :2], + sel_pos[:, :2] + sel_pos[:, 2:] * [1, 0], + sel_pos[:, :2] + sel_pos[:, 2:], + sel_pos[:, :2] + sel_pos[:, 2:] * [0, 1], ], [1, 0, 2], ) - if not img: - under_ax.add_collection( - collections.PolyCollection( - verts, - facecolor=axis_facecolor, - edgecolor=axis_spinecolor, - linewidth=1.0, + if not img: # Not needed for image plots. + collection = collections.PolyCollection( + verts, + facecolor=axis_facecolor, + edgecolor=axis_spinecolor, + linewidth=1.0, + ) + under_ax.add_collection(collection) + + if select: + # Configure the lasso-selection tool + fig.lasso = SelectFromCollection( + ax=under_ax, + collection=collection, + names=shown_ch_names, + alpha_nonselected=0, + alpha_selected=1, + linewidth_nonselected=0, + linewidth_selected=0.7, ) - ) # Not needed for image plots. + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + selection_inds = np.flatnonzero( + np.isin(shown_ch_names, event.ch_names) + ) + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) for ax in axs: yield ax, ax._mne_ch_idx @@ -270,6 +308,7 @@ def _plot_topo( unified=False, img=False, axes=None, + select=False, ): """Plot on sensor layout.""" import matplotlib.pyplot as plt @@ -322,6 +361,7 @@ def _plot_topo( unified=unified, img=img, axes=axes, + select=select, ) for ax, ch_idx in my_topo_plot: @@ -340,8 +380,17 @@ def _plot_topo( def _plot_topo_onpick(event, show_func): """Onpick callback that shows a single channel in a new figure.""" - # make sure that the swipe gesture in OS-X doesn't open many figures orig_ax = event.inaxes + fig = orig_ax.figure + + # If we are doing lasso select, allow it to handle the click instead. + if hasattr(fig, "lasso") and event.key in ["control", "ctrl+shift"]: + return + + # make sure that the swipe gesture in OS-X doesn't open many figures + if fig.canvas._key in ["shift", "alt"]: + return + import matplotlib.pyplot as plt try: @@ -838,9 +887,10 @@ def _plot_evoked_topo( merge_channels=False, legend=True, axes=None, + noise_cov=None, exclude="bads", + select=False, show=True, - noise_cov=None, ): """Plot 2D topography of evoked responses. @@ -912,6 +962,10 @@ def _plot_evoked_topo( exclude : list of str | 'bads' Channels names to exclude from being shown. If 'bads', the bad channels are excluded. By default, exclude is set to 'bads'. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. show : bool Show figure if True. @@ -1091,6 +1145,7 @@ def _plot_evoked_topo( y_label=y_label, unified=True, axes=axes, + select=select, ) add_background_image(fig, fig_background) @@ -1098,7 +1153,10 @@ def _plot_evoked_topo( if legend is not False: legend_loc = 0 if legend is True else legend labels = [e.comment if e.comment else "Unknown" for e in evoked] - handles = fig.axes[0].lines[: len(evoked)] + if select: + handles = fig.axes[0].lines[1 : len(evoked) + 1] + else: + handles = fig.axes[0].lines[: len(evoked)] legend = plt.legend( labels=labels, handles=handles, loc=legend_loc, prop={"size": 10} ) @@ -1157,6 +1215,7 @@ def plot_topo_image_epochs( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): """Plot Event Related Potential / Fields image on topographies. @@ -1204,6 +1263,12 @@ def plot_topo_image_epochs( :func:`matplotlib.pyplot.imshow`. Defaults to ``None``. font_color : color The color of tick labels in the colorbar. Defaults to white. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 show : bool Whether to show the figure. Defaults to ``True``. @@ -1293,6 +1358,7 @@ def plot_topo_image_epochs( y_label="Epoch", unified=True, img=True, + select=select, ) add_background_image(fig, fig_background) plt_show(show) diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index 256d5741ad3..b8b3fe29a4d 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -212,6 +212,26 @@ class Contours(UIEvent): contours: list[str] +@dataclass +@fill_doc +class ChannelsSelect(UIEvent): + """Indicates that the user has selected one or more channels. + + Parameters + ---------- + ch_names : list of str + The names of the channels that were selected. + + Attributes + ---------- + %(ui_event_name_source)s + ch_names : list of str + The names of the channels that were selected. + """ + + ch_names: list[str] + + def _get_event_channel(fig): """Get the event channel associated with a figure. diff --git a/mne/viz/utils.py b/mne/viz/utils.py index a09da17de7d..f9d64c49ec8 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -58,7 +58,7 @@ warn, ) from ..utils.misc import _identity_function -from .ui_events import ColormapRange, publish, subscribe +from .ui_events import ChannelsSelect, ColormapRange, publish, subscribe _channel_type_prettyprint = { "eeg": "EEG channel", @@ -807,12 +807,12 @@ def _fake_click(fig, ax, point, xform="ax", button=1, kind="press", key=None): ) -def _fake_keypress(fig, key): +def _fake_keypress(fig, key, kind="press"): from matplotlib import backend_bases fig.canvas.callbacks.process( - "key_press_event", - backend_bases.KeyEvent(name="key_press_event", canvas=fig.canvas, key=key), + f"key_{kind}_event", + backend_bases.KeyEvent(name=f"key_{kind}_event", canvas=fig.canvas, key=key), ) @@ -952,7 +952,7 @@ def plot_sensors( Whether to plot the sensors as 3d, topomap or as an interactive sensor selection dialog. Available options ``'topomap'``, ``'3d'``, ``'select'``. If ``'select'``, a set of channels can be selected - interactively by using lasso selector or clicking while holding control + interactively by using lasso selector or clicking while holding the control key. The selected channels are returned along with the figure instance. Defaults to ``'topomap'``. ch_type : None | str @@ -1163,10 +1163,10 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): if event.mouseevent.inaxes != ax: return - if event.mouseevent.key == "control" and fig.lasso is not None: + if fig.lasso is not None and event.mouseevent.key in ["control", "ctrl+shift"]: + # Add the sensor to the selection instead of showing its name. for ind in event.ind: fig.lasso.select_one(ind) - return if show_names: return # channel names already visible @@ -1272,7 +1272,17 @@ def _plot_sensors_2d( lw=linewidth, ) if kind == "select": - fig.lasso = SelectFromCollection(ax, pts, ch_names) + fig.lasso = SelectFromCollection(ax, pts, names=ch_names) + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + selection_inds = np.flatnonzero(np.isin(ch_names, event.ch_names)) + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) else: fig.lasso = None @@ -1595,11 +1605,14 @@ def _update(self): class SelectFromCollection: - """Select channels from a matplotlib collection using ``LassoSelector``. + """Select objects from a matplotlib collection using ``LassoSelector``. - Selected channels are saved in the ``selection`` attribute. This tool - highlights selected points by fading other points out (i.e., reducing their - alpha values). + The names of the selected objects are saved in the ``selection`` attribute. + This tool highlights selected objects by fading other objects out (i.e., + reducing their alpha values). + + Holding down the Control key will add to the current selection, and holding down + Control+Shift will remove from the current selection. Parameters ---------- @@ -1607,112 +1620,144 @@ class SelectFromCollection: Axes to interact with. collection : instance of matplotlib collection Collection you want to select from. - alpha_other : 0 <= float <= 1 - To highlight a selection, this tool sets all selected points to an - alpha value of 1 and non-selected points to ``alpha_other``. - Defaults to 0.3. - linewidth_other : float - Linewidth to use for non-selected sensors. Default is 1. + names : list of str + The names of the object. The selection is returned as a subset of these names. + alpha_selected : float + Alpha for selected objects (0=tranparant, 1=opaque). + alpha_nonselected : float + Alpha for non-selected objects (0=tranparant, 1=opaque). + linewidth_selected : float + Linewidth for the borders of selected objects. + linewidth_nonselected : float + Linewidth for the borders of non-selected objects. Notes ----- - This tool selects collection objects based on their *origins* - (i.e., ``offsets``). Calls all callbacks in self.callbacks when selection - is ready. + This tool selects collection objects which bounding boxes intersect with a lasso + path. Calls all callbacks in self.callbacks when selection is ready. """ def __init__( self, ax, collection, - ch_names, - alpha_other=0.5, - linewidth_other=0.5, + *, + names, alpha_selected=1, + alpha_nonselected=0.5, linewidth_selected=1, + linewidth_nonselected=0.5, + verbose=None, ): from matplotlib.widgets import LassoSelector + self.fig = ax.figure self.canvas = ax.figure.canvas self.collection = collection - self.ch_names = ch_names - self.alpha_other = alpha_other - self.linewidth_other = linewidth_other + self.names = names self.alpha_selected = alpha_selected + self.alpha_nonselected = alpha_nonselected self.linewidth_selected = linewidth_selected + self.linewidth_nonselected = linewidth_nonselected - self.xys = collection.get_offsets() - self.Npts = len(self.xys) + from matplotlib.collections import PolyCollection + from matplotlib.path import Path - # Ensure that we have separate colors for each object + if isinstance(collection, PolyCollection): + self.paths = collection.get_paths() + else: + self.paths = [Path([point]) for point in collection.get_offsets()] + self.Npts = len(self.paths) + if self.Npts != len(names): + raise ValueError( + f"Number of names ({len(names)}) does not match the number of objects " + f"in the collection ({self.Npts})." + ) + + # Ensure that we have colors for each object. self.fc = collection.get_facecolors() self.ec = collection.get_edgecolors() - self.lw = collection.get_linewidths() if len(self.fc) == 0: raise ValueError("Collection must have a facecolor") elif len(self.fc) == 1: self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1) + if len(self.ec) == 0: + self.ec = np.zeros((self.Npts, 4)) # all black + elif len(self.ec) == 1: self.ec = np.tile(self.ec, self.Npts).reshape(self.Npts, -1) - self.fc[:, -1] = self.alpha_other # deselect in the beginning - self.ec[:, -1] = self.alpha_other - self.lw = np.full(self.Npts, self.linewidth_other) + self.lw = np.full(self.Npts, float(self.linewidth_nonselected)) + # Initialize the lasso selector self.lasso = LassoSelector( ax, onselect=self.on_select, props=dict(color="red", linewidth=0.5) ) self.selection = list() + self.selection_inds = np.array([], dtype="int") self.callbacks = list() + # Deselect everything in the beginning. + self.style_objects() + + # For backwards compatibility + @property + def ch_names(self): + return self.names + + def notify(self): + """Notify listeners that a selection has been made.""" + logger.info(f"Selected channels: {self.selection}") + for callback in self.callbacks: + callback() + def on_select(self, verts): """Select a subset from the collection.""" from matplotlib.path import Path - if len(verts) <= 3: # Seems to be a good way to exclude single clicks. + # Don't respond to single clicks without extra keys being hold down. + # Figures like plot_evoked_topo want to do something else with them. + if len(verts) <= 3 and self.canvas._key not in ["control", "ctrl+shift"]: return path = Path(verts) - inds = np.nonzero([path.contains_point(xy) for xy in self.xys])[0] + inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0] if self.canvas._key == "control": # Appending selection. - sels = [np.where(self.ch_names == c)[0][0] for c in self.selection] - inters = set(inds) - set(sels) - inds = list(inters.union(set(sels) - set(inds))) - - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selection_inds = np.union1d(self.selection_inds, inds).astype("int") + elif self.canvas._key == "ctrl+shift": + self.selection_inds = np.setdiff1d(self.selection_inds, inds).astype("int") + else: + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() def select_one(self, ind): """Select or deselect one sensor.""" - ch_name = self.ch_names[ind] - if ch_name in self.selection: - sel_ind = self.selection.index(ch_name) - self.selection.pop(sel_ind) + if self.canvas._key == "control": + self.selection_inds = np.union1d(self.selection_inds, [ind]) + elif self.canvas._key == "ctrl+shift": + self.selection_inds = np.setdiff1d(self.selection_inds, [ind]) else: - self.selection.append(ch_name) - inds = np.isin(self.ch_names, self.selection).nonzero()[0] - self.style_sensors(inds) + return # don't notify() + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() - def notify(self): - """Notify listeners that a selection has been made.""" - for callback in self.callbacks: - callback() - def select_many(self, inds): """Select many sensors using indices (for predefined selections).""" - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() - def style_sensors(self, inds): + def style_objects(self): """Style selected sensors as "active".""" # reset - self.fc[:, -1] = self.alpha_other - self.ec[:, -1] = self.alpha_other / 2 - self.lw[:] = self.linewidth_other + self.fc[:, -1] = self.alpha_nonselected + self.ec[:, -1] = self.alpha_nonselected / 2 + self.lw[:] = self.linewidth_nonselected # style sensors at `inds` - self.fc[inds, -1] = self.alpha_selected - self.ec[inds, -1] = self.alpha_selected - self.lw[inds] = self.linewidth_selected + self.fc[self.selection_inds, -1] = self.alpha_selected + self.ec[self.selection_inds, -1] = self.alpha_selected + self.lw[self.selection_inds] = self.linewidth_selected self.collection.set_facecolors(self.fc) self.collection.set_edgecolors(self.ec) self.collection.set_linewidths(self.lw)