Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add interpolate_to method #13044

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
824457d
working implem
antoinecollas Dec 30, 2024
5e5f16b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2024
4236efb
pre-commit
antoinecollas Dec 30, 2024
36981ca
Merge branch 'interpolate_to' of https://github.com/antoinecollas/mne…
antoinecollas Dec 30, 2024
e7cefd8
Merge branch 'main' into interpolate_to
antoinecollas Dec 30, 2024
3a0383a
minor changes
antoinecollas Dec 30, 2024
3ea4cad
fix nested imports
antoinecollas Dec 30, 2024
29f7ce4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2024
8a05713
fix comment
antoinecollas Dec 30, 2024
de0ddc4
add test
antoinecollas Dec 30, 2024
254fdac
fix docstring
antoinecollas Dec 30, 2024
4bf3f1d
fix docstring
antoinecollas Dec 30, 2024
2f27599
rm plt.tight_layout
antoinecollas Dec 30, 2024
872ba1d
taller figure
antoinecollas Dec 30, 2024
045496c
[autofix.ci] apply automated fixes
autofix-ci[bot] Jan 1, 2025
568656d
Merge branch 'main' into interpolate_to
antoinecollas Jan 1, 2025
6fb3ae5
fix figure layout
antoinecollas Jan 3, 2025
dd093a8
improve test
antoinecollas Jan 3, 2025
9551016
simplify getting original data
antoinecollas Jan 3, 2025
fa326ad
simplify setting interpolated data
antoinecollas Jan 3, 2025
1a5ca89
merge two lines
antoinecollas Jan 3, 2025
4333f08
add spline method
antoinecollas Jan 6, 2025
4e25616
add splive vs mne to doc
antoinecollas Jan 6, 2025
a249321
keep all modalities in test
antoinecollas Jan 6, 2025
fb8804f
use self.info instead of old_info
antoinecollas Jan 6, 2025
7ff7f51
fix info when MNE interpolation
antoinecollas Jan 6, 2025
1334fb3
fix origin in spline method
antoinecollas Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -2504,3 +2504,11 @@ @article{OyamaEtAl2015
year = {2015},
pages = {24--36},
}

@inproceedings{MellotEtAl2024,
title = {Physics-informed and Unsupervised Riemannian Domain Adaptation for Machine Learning on Heterogeneous EEG Datasets},
author = {Mellot, Apolline and Collas, Antoine and Chevallier, Sylvain and Engemann, Denis and Gramfort, Alexandre},
booktitle = {Proceedings of the 32nd European Signal Processing Conference (EUSIPCO)},
year = {2024},
address = {Lyon, France}
}
76 changes: 76 additions & 0 deletions examples/preprocessing/interpolate_to.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
.. _ex-interpolate-to-any-montage:

======================================================
Interpolate EEG data to any montage
======================================================

This example demonstrates how to interpolate EEG channels to match a given montage.
This can be useful for standardizing
EEG channel layouts across different datasets (see :footcite:`MellotEtAl2024`).

- Using the field interpolation for EEG data.
- Using the target montage "biosemi16".

In this example, the data from the original EEG channels will be
interpolated onto the positions defined by the "biosemi16" montage.
"""

# Authors: Antoine Collas <[email protected]>
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import matplotlib.pyplot as plt

import mne
from mne.channels import make_standard_montage
from mne.datasets import sample

print(__doc__)

# %%
# Load EEG data
data_path = sample.data_path()
eeg_file_path = data_path / "MEG" / "sample" / "sample_audvis-ave.fif"
evoked = mne.read_evokeds(eeg_file_path, condition="Left Auditory", baseline=(None, 0))

# Select only EEG channels
evoked.pick("eeg")

# Plot the original EEG layout
evoked.plot(exclude=[], picks="eeg")

# %%
# Define the target montage
standard_montage = make_standard_montage("biosemi16")

# %%
# Use interpolate_to to project EEG data to the standard montage
evoked_interpolated_spline = evoked.copy().interpolate_to(
standard_montage, method="spline"
)

# Plot the interpolated EEG layout
evoked_interpolated_spline.plot(exclude=[], picks="eeg")

# %%
# Use interpolate_to to project EEG data to the standard montage
evoked_interpolated_mne = evoked.copy().interpolate_to(standard_montage, method="MNE")

# Plot the interpolated EEG layout
evoked_interpolated_mne.plot(exclude=[], picks="eeg")

# %%
# Comparing before and after interpolation
fig, axs = plt.subplots(3, 1, figsize=(8, 6), constrained_layout=True)
evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False)
axs[0].set_title("Original EEG Layout")
evoked_interpolated_spline.plot(exclude=[], picks="eeg", axes=axs[1], show=False)
axs[1].set_title("Interpolated to Standard 1020 Montage using spline interpolation")
evoked_interpolated_mne.plot(exclude=[], picks="eeg", axes=axs[2], show=False)
axs[2].set_title("Interpolated to Standard 1020 Montage using MNE interpolation")

# %%
# References
# ----------
# .. footbibliography::
105 changes: 105 additions & 0 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,111 @@ def interpolate_bads(

return self

def interpolate_to(self, montage, origin="auto", method="spline", reg=0.0):
"""Interpolate EEG data onto a new montage.

Parameters
----------
montage : DigMontage
The target montage containing channel positions to interpolate onto.
origin : array-like, shape (3,) | str
Origin of the sphere in the head coordinate frame and in meters.
Can be ``'auto'`` (default), which means a head-digitization-based
origin fit.
method : str
Method to use for EEG channels.
Supported methods are 'spline' (default) and 'MNE'.

.. warning::
Be careful, only EEG channels are interpolated. Other channel types are
not interpolated.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this warning applies to the entire function, not the method parameter specifically. Perhaps move it to the main text, so above the Parameters line.


reg : float
The regularization parameter for the interpolation method (if applicable).
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The regularization parameter for the interpolation method (if applicable).
The regularization parameter for the interpolation method (only used when the method is 'spline').


Returns
-------
inst : instance of Raw, Epochs, or Evoked
The instance with updated channel locations and data.

Notes
-----
This method is useful for standardizing EEG layouts across datasets.

.. versionadded:: 1.10.0
"""
from ..forward._field_interpolation import _map_meg_or_eeg_channels
from .interpolation import _make_interpolation_matrix

# Get target positions from the montage
ch_pos = montage.get_positions()["ch_pos"]
target_ch_names = list(ch_pos.keys())
if len(target_ch_names) == 0:
raise ValueError(
"The provided montage does not contain any channel positions."
)
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved

# Check the method is valid
_check_option("method", method, ["spline", "MNE"])

# Ensure data is loaded
_check_preload(self, "interpolation")

# Extract positions and data for EEG channels
picks_from = pick_types(self.info, meg=False, eeg=True, exclude=[])
if len(picks_from) == 0:
raise ValueError("No EEG channels available for interpolation.")

# Create a new info structure
sfreq = self.info["sfreq"]
ch_types = ["eeg"] * len(target_ch_names)
new_info = create_info(ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types)
new_info.set_montage(montage)

# Compute mapping from current montage to target montage
if method == "spline":
# pos_from = np.array(
# [self.info["chs"][idx]["loc"][:3] for idx in picks_from]
# )

origin = _check_origin(origin, self.info)
pos_from = self.info._get_channel_positions(picks_from)
pos_from = pos_from - origin
pos_to = np.stack(list(ch_pos.values()), axis=0)

def _check_pos_sphere(pos):
distance = np.linalg.norm(pos, axis=-1)
distance = np.mean(distance / np.mean(distance))
if np.abs(1.0 - distance) > 0.1:
warn(
"Your spherical fit is poor, interpolation results are "
"likely to be inaccurate."
)

_check_pos_sphere(pos_from)
_check_pos_sphere(pos_to)

mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg)

elif method == "MNE":
info_eeg = pick_info(self.info, picks_from)
mapping = _map_meg_or_eeg_channels(
info_eeg, new_info, mode="accurate", origin="auto"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be:

Suggested change
info_eeg, new_info, mode="accurate", origin="auto"
info_eeg, new_info, mode="accurate", origin=origin

)

# Apply the interpolation mapping
data_orig = self.get_data(picks=picks_from)
data_interp = mapping.dot(data_orig)

# Update bad channels
new_info["bads"] = [ch for ch in self.info["bads"] if ch in target_ch_names]

# Update the instance's info and data
self.info = new_info
self._data = data_interp

return self


@verbose
def rename_channels(info, mapping, allow_duplicates=False, *, verbose=None):
Expand Down
38 changes: 37 additions & 1 deletion mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mne import Epochs, pick_channels, pick_types, read_events
from mne._fiff.constants import FIFF
from mne._fiff.proj import _has_eeg_average_ref_proj
from mne.channels import make_dig_montage
from mne.channels import make_dig_montage, make_standard_montage
from mne.channels.interpolation import _make_interpolation_matrix
from mne.datasets import testing
from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx
Expand Down Expand Up @@ -439,3 +439,39 @@ def test_method_str():
raw.interpolate_bads(method="spline")
raw.pick("eeg", exclude=())
raw.interpolate_bads(method="spline")


@pytest.mark.parametrize("montage_name", ["biosemi16", "standard_1020"])
@pytest.mark.parametrize("method", ["spline", "MNE"])
def test_interpolate_to_eeg(montage_name, method):
"""Test the interpolate_to method for EEG."""
# Load EEG data
raw, epochs_eeg = _load_data("eeg")
epochs_eeg = epochs_eeg.copy()
assert not _has_eeg_average_ref_proj(epochs_eeg.info)

# Load data
raw.load_data()

# Create a target montage
montage = make_standard_montage(montage_name)

# Copy the raw object and apply interpolation
raw_interpolated = raw.copy().interpolate_to(montage, method=method)

# Check if channel names match the target montage
assert set(raw_interpolated.info["ch_names"]) == set(montage.ch_names)

# Check if the data was interpolated correctly
assert raw_interpolated.get_data().shape == (len(montage.ch_names), raw.n_times)

# Ensure original data is not altered
assert raw.info["ch_names"] != raw_interpolated.info["ch_names"]
assert raw.get_data().shape == (len(raw.info["ch_names"]), raw.n_times)

# Validate that bad channels are carried over
raw.info["bads"] = [raw.info["ch_names"][0]]
raw_interpolated = raw.copy().interpolate_to(montage, method=method)
assert raw_interpolated.info["bads"] == [
ch for ch in raw.info["bads"] if ch in montage.ch_names
]
Loading