Skip to content

Commit

Permalink
Allow lasso selection sensors in a plot_evoked_topo (#12071)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel McCloy <[email protected]>
  • Loading branch information
3 people authored Jan 23, 2025
1 parent 6028982 commit aca4965
Show file tree
Hide file tree
Showing 12 changed files with 348 additions and 109 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12071.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add new ``select`` parameter to :func:`mne.viz.plot_evoked_topo` and :meth:`mne.Evoked.plot_topo` to toggle lasso selection of sensors, by `Marijn van Vliet`_.
2 changes: 2 additions & 0 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,7 @@ def plot_topo_image(
fig_facecolor="k",
fig_background=None,
font_color="w",
select=False,
show=True,
):
return plot_topo_image_epochs(
Expand All @@ -1371,6 +1372,7 @@ def plot_topo_image(
fig_facecolor=fig_facecolor,
fig_background=fig_background,
font_color=font_color,
select=select,
show=show,
)

Expand Down
2 changes: 2 additions & 0 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ def plot_topo(
background_color="w",
noise_cov=None,
exclude="bads",
select=False,
show=True,
):
""".
Expand All @@ -639,6 +640,7 @@ def plot_topo(
background_color=background_color,
noise_cov=noise_cov,
exclude=exclude,
select=select,
show=show,
)

Expand Down
6 changes: 3 additions & 3 deletions mne/viz/_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,11 @@ def _create_ch_location_fig(self, pick):
show=False,
)
# highlight desired channel & disable interactivity
inds = np.isin(fig.lasso.ch_names, [ch_name])
fig.lasso.selection_inds = np.isin(fig.lasso.names, [ch_name])
fig.lasso.disconnect()
fig.lasso.alpha_other = 0.3
fig.lasso.alpha_nonselected = 0.3
fig.lasso.linewidth_selected = 3
fig.lasso.style_sensors(inds)
fig.lasso.style_objects()

return fig

Expand Down
2 changes: 1 addition & 1 deletion mne/viz/_mpl_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,7 +1536,7 @@ def _update_selection(self):
def _update_highlighted_sensors(self):
"""Update the sensor plot to show what is selected."""
inds = np.isin(
self.mne.fig_selection.lasso.ch_names, self.mne.ch_names[self.mne.picks]
self.mne.fig_selection.lasso.names, self.mne.ch_names[self.mne.picks]
).nonzero()[0]
self.mne.fig_selection.lasso.select_many(inds)

Expand Down
13 changes: 12 additions & 1 deletion mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,7 @@ def plot_evoked_topo(
background_color="w",
noise_cov=None,
exclude="bads",
select=False,
show=True,
):
"""Plot 2D topography of evoked responses.
Expand Down Expand Up @@ -1218,6 +1219,15 @@ def plot_evoked_topo(
exclude : list of str | ``'bads'``
Channels names to exclude from the plot. If ``'bads'``, the
bad channels are excluded. By default, exclude is set to ``'bads'``.
select : bool
Whether to enable the lasso-selection tool to enable the user to select
channels. The selected channels will be available in
``fig.lasso.selection``.
.. versionadded:: 1.10.0
exclude : list of str | ``'bads'``
Channels names to exclude from the plot. If ``'bads'``, the
bad channels are excluded. By default, exclude is set to ``'bads'``.
show : bool
Show figure if True.
Expand Down Expand Up @@ -1274,10 +1284,11 @@ def plot_evoked_topo(
font_color=font_color,
merge_channels=merge_grads,
legend=legend,
noise_cov=noise_cov,
axes=axes,
exclude=exclude,
select=select,
show=show,
noise_cov=noise_cov,
)


Expand Down
39 changes: 14 additions & 25 deletions mne/viz/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,36 +1088,25 @@ def test_plot_sensors(raw):
pytest.raises(TypeError, plot_sensors, raw) # needs to be info
pytest.raises(ValueError, plot_sensors, raw.info, kind="sasaasd")
plt.close("all")

# Test lasso selection.
fig, sels = raw.plot_sensors("select", show_names=True)
ax = fig.axes[0]

# Click with no sensors
_fake_click(fig, ax, (0.0, 0.0), xform="data")
_fake_click(fig, ax, (0, 0.0), xform="data", kind="release")
assert fig.lasso.selection == []

# Lasso with 1 sensor (upper left)
_fake_click(fig, ax, (0, 1), xform="ax")
fig.canvas.draw()
assert fig.lasso.selection == []
_fake_click(fig, ax, (0.65, 1), xform="ax", kind="motion")
_fake_click(fig, ax, (0.65, 0.7), xform="ax", kind="motion")
_fake_keypress(fig, "control")
_fake_click(fig, ax, (0, 0.7), xform="ax", kind="release", key="control")
# Lasso a single sensor.
_fake_click(fig, ax, (-0.13, 0.13), xform="data")
_fake_click(fig, ax, (-0.11, 0.13), xform="data", kind="motion")
_fake_click(fig, ax, (-0.11, 0.06), xform="data", kind="motion")
_fake_click(fig, ax, (-0.13, 0.06), xform="data", kind="motion")
_fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion")
_fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="release")
assert fig.lasso.selection == ["MEG 0121"]

# check that point appearance changes
fc = fig.lasso.collection.get_facecolors()
ec = fig.lasso.collection.get_edgecolors()
assert (fc[:, -1] == [0.5, 1.0, 0.5]).all()
assert (ec[:, -1] == [0.25, 1.0, 0.25]).all()

_fake_click(fig, ax, (0.7, 1), xform="ax", kind="motion", key="control")
xy = ax.collections[0].get_offsets()
_fake_click(fig, ax, xy[2], xform="data", key="control") # single sel
# Add another sensor with a single click.
_fake_keypress(fig, "control")
_fake_click(fig, ax, (-0.1278, 0.0318), xform="data")
_fake_click(fig, ax, (-0.1278, 0.0318), xform="data", kind="release")
_fake_keypress(fig, "control", kind="release")
assert fig.lasso.selection == ["MEG 0121", "MEG 0131"]
_fake_click(fig, ax, xy[2], xform="data", key="control") # deselect
assert fig.lasso.selection == ["MEG 0121"]
plt.close("all")

raw.info["dev_head_t"] = None # like empty room
Expand Down
36 changes: 35 additions & 1 deletion mne/viz/tests/test_topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from mne.viz.evoked import _line_plot_onselect
from mne.viz.topo import _imshow_tfr, _plot_update_evoked_topo_proj, iter_topography
from mne.viz.utils import _fake_click
from mne.viz.utils import _fake_click, _fake_keypress

base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
evoked_fname = base_dir / "test-ave.fif"
Expand Down Expand Up @@ -231,6 +231,16 @@ def test_plot_topo():
break
plt.close("all")

# Test plot_topo with selection of channels enabled.
fig = evoked.plot_topo(select=True)
ax = fig.axes[0]
_fake_click(fig, ax, (0.05, 0.62), xform="data")
_fake_click(fig, ax, (0.2, 0.62), xform="data", kind="motion")
_fake_click(fig, ax, (0.2, 0.7), xform="data", kind="motion")
_fake_click(fig, ax, (0.05, 0.7), xform="data", kind="motion")
_fake_click(fig, ax, (0.05, 0.7), xform="data", kind="release")
assert fig.lasso.selection == ["MEG 0113", "MEG 0112", "MEG 0111"]


def test_plot_topo_nirs(fnirs_evoked):
"""Test plotting of ERP topography for nirs data."""
Expand Down Expand Up @@ -296,6 +306,30 @@ def test_plot_topo_image_epochs():
assert qm_cmap[0] is cmap


def test_plot_topo_select():
"""Test selecting sensors in an ERP topography plot."""
# Show topography
evoked = _get_epochs().average()
fig = plot_evoked_topo(evoked, select=True)
ax = fig.axes[0]

# Lasso select 3 out of the 6 sensors.
_fake_click(fig, ax, (0.05, 0.5), xform="data")
_fake_click(fig, ax, (0.2, 0.5), xform="data", kind="motion")
_fake_click(fig, ax, (0.2, 0.6), xform="data", kind="motion")
_fake_click(fig, ax, (0.05, 0.6), xform="data", kind="motion")
_fake_click(fig, ax, (0.05, 0.5), xform="data", kind="motion")
_fake_click(fig, ax, (0.05, 0.5), xform="data", kind="release")
assert fig.lasso.selection == ["MEG 0132", "MEG 0133", "MEG 0131"]

# Add another sensor with a single click.
_fake_keypress(fig, "control")
_fake_click(fig, ax, (0.11, 0.65), xform="data")
_fake_click(fig, ax, (0.21, 0.65), xform="data", kind="release")
_fake_keypress(fig, "control", kind="release")
assert fig.lasso.selection == ["MEG 0111", "MEG 0132", "MEG 0133", "MEG 0131"]


def test_plot_tfr_topo():
"""Test plotting of TFR data."""
epochs = _get_epochs()
Expand Down
69 changes: 69 additions & 0 deletions mne/viz/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap
from mne.viz.ui_events import ColormapRange, link, subscribe
from mne.viz.utils import (
SelectFromCollection,
_compute_scalings,
_fake_click,
_fake_keypress,
Expand Down Expand Up @@ -274,3 +275,71 @@ def callback(event):
cmap_new1 = fig.axes[0].CB.mappable.get_cmap().name
cmap_new2 = fig2.axes[0].CB.mappable.get_cmap().name
assert cmap_new1 == cmap_new2 == cmap_want != cmap_old


def test_select_from_collection():
"""Test the lasso selector for matplotlib figures."""
fig, ax = plt.subplots()
collection = ax.scatter([1, 2, 2, 1], [1, 1, 0, 0], color="black", edgecolor="red")
ax.set_xlim(-1, 4)
ax.set_ylim(-1, 2)
lasso = SelectFromCollection(ax, collection, names=["A", "B", "C", "D"])
assert lasso.selection == []

# Make a selection with no patches inside of it.
_fake_click(fig, ax, (0, 0), xform="data")
_fake_click(fig, ax, (0.5, 0), xform="data", kind="motion")
_fake_click(fig, ax, (0.5, 1), xform="data", kind="motion")
_fake_click(fig, ax, (0.5, 1), xform="data", kind="release")
assert lasso.selection == []

# Doing a single click on a patch should not select it.
_fake_click(fig, ax, (1, 1), xform="data")
assert lasso.selection == []

# Make a selection with two patches in it.
_fake_click(fig, ax, (0, 0.5), xform="data")
_fake_click(fig, ax, (3, 0.5), xform="data", kind="motion")
_fake_click(fig, ax, (3, 1.5), xform="data", kind="motion")
_fake_click(fig, ax, (0, 1.5), xform="data", kind="motion")
_fake_click(fig, ax, (0, 0.5), xform="data", kind="motion")
_fake_click(fig, ax, (0, 0.5), xform="data", kind="release")
assert lasso.selection == ["A", "B"]

# Use Control key to lasso an additional patch.
_fake_keypress(fig, "control")
_fake_click(fig, ax, (0.5, -0.5), xform="data")
_fake_click(fig, ax, (1.5, -0.5), xform="data", kind="motion")
_fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion")
_fake_click(fig, ax, (0.5, 0.5), xform="data", kind="motion")
_fake_click(fig, ax, (0.5, 0.5), xform="data", kind="release")
_fake_keypress(fig, "control", kind="release")
assert lasso.selection == ["A", "B", "D"]

# Use CTRL+SHIFT to remove a patch.
_fake_keypress(fig, "ctrl+shift")
_fake_click(fig, ax, (0.5, 0.5), xform="data")
_fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion")
_fake_click(fig, ax, (1.5, 1.5), xform="data", kind="motion")
_fake_click(fig, ax, (0.5, 1.5), xform="data", kind="motion")
_fake_click(fig, ax, (0.5, 1.5), xform="data", kind="release")
_fake_keypress(fig, "ctrl+shift", kind="release")
assert lasso.selection == ["B", "D"]

# Check that the two selected patches have a different appearance.
fc = lasso.collection.get_facecolors()
ec = lasso.collection.get_edgecolors()
assert (fc[:, -1] == [0.5, 1.0, 0.5, 1.0]).all()
assert (ec[:, -1] == [0.25, 1.0, 0.25, 1.0]).all()

# Test adding and removing single channels.
lasso.select_one(2) # should not do anything without modifier keys
assert lasso.selection == ["B", "D"]
_fake_keypress(fig, "control")
lasso.select_one(2) # add to selection
_fake_keypress(fig, "control", kind="release")
assert lasso.selection == ["B", "C", "D"]
_fake_keypress(fig, "ctrl+shift")
lasso.select_one(1) # remove from selection
assert lasso.selection == ["C", "D"]
_fake_keypress(fig, "ctrl+shift", kind="release")
Loading

0 comments on commit aca4965

Please sign in to comment.