Skip to content

Commit

Permalink
Additional option to flatten TMS artifact (#6915)
Browse files Browse the repository at this point in the history
Co-authored-by: Mainak Jas <[email protected]>
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
3 people authored Dec 16, 2024
1 parent f3a7fde commit 610ec2a
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 16 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ jobs:
- run: ./tools/github_actions_dependencies.sh
# Minimal commands on Linux (macOS stalls)
- run: ./tools/get_minimal_commands.sh
if: ${{ startswith(matrix.os, 'ubuntu') }}
if: startswith(matrix.os, 'ubuntu') && matrix.kind != 'minimal' && matrix.kind != 'old'
- run: ./tools/github_actions_infos.sh
# Check Qt
- run: ./tools/check_qt_import.sh $MNE_QT_BACKEND
if: ${{ env.MNE_QT_BACKEND != '' }}
if: env.MNE_QT_BACKEND != ''
- name: Run tests with no testing data
run: MNE_SKIP_TESTING_DATASET_TESTS=true pytest -m "not (ultraslowtest or pgtest)" --tb=short --cov=mne --cov-report xml -vv -rfE mne/
if: matrix.kind == 'minimal'
Expand Down
1 change: 1 addition & 0 deletions doc/changes/devel/6915.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add option to :func:`mne.preprocessing.fix_stim_artifact` to use baseline average to flatten TMS pulse artifact by `Fahimeh Mamashli`_ and `Padma Sundaram`_ and `Mohammad Daneshzand`_.
5 changes: 2 additions & 3 deletions examples/datasets/brainstorm_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
=====================================
Here we compute the evoked from raw for the Brainstorm
tutorial dataset. For comparison, see :footcite:`TadelEtAl2011` and:
https://neuroimage.usc.edu/brainstorm/Tutorials/MedianNerveCtf
tutorial dataset. For comparison, see :footcite:`TadelEtAl2011` and
https://neuroimage.usc.edu/brainstorm/Tutorials/MedianNerveCtf.
"""

# Authors: Mainak Jas <[email protected]>
Expand Down
65 changes: 56 additions & 9 deletions mne/preprocessing/stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..event import find_events
from ..evoked import Evoked
from ..io import BaseRaw
from ..utils import _check_option, _check_preload, fill_doc
from ..utils import _check_option, _check_preload, _validate_type, fill_doc


def _get_window(start, end):
Expand All @@ -20,7 +20,9 @@ def _get_window(start, end):
return window


def _fix_artifact(data, window, picks, first_samp, last_samp, mode):
def _fix_artifact(
data, window, picks, first_samp, last_samp, base_tmin, base_tmax, mode
):
"""Modify original data by using parameter data."""
if mode == "linear":
x = np.array([first_samp, last_samp])
Expand All @@ -32,6 +34,10 @@ def _fix_artifact(data, window, picks, first_samp, last_samp, mode):
data[picks, first_samp:last_samp] = (
data[picks, first_samp:last_samp] * window[np.newaxis, :]
)
if mode == "constant":
data[picks, first_samp:last_samp] = data[picks, base_tmin:base_tmax].mean(
axis=1
)[:, None]


@fill_doc
Expand All @@ -41,6 +47,8 @@ def fix_stim_artifact(
event_id=None,
tmin=0.0,
tmax=0.01,
*,
baseline=None,
mode="linear",
stim_channel=None,
picks=None,
Expand All @@ -63,10 +71,23 @@ def fix_stim_artifact(
Start time of the interpolation window in seconds.
tmax : float
End time of the interpolation window in seconds.
mode : 'linear' | 'window'
baseline : None | tuple, shape (2,)
The baseline to use when ``mode='constant'``, in which case it
must be non-None.
.. versionadded:: 1.8
mode : 'linear' | 'window' | 'constant'
Way to fill the artifacted time interval.
'linear' does linear interpolation
'window' applies a (1 - hanning) window.
``"linear"``
Does linear interpolation.
``"window"``
Applies a ``(1 - hanning)`` window.
``"constant"``
Uses baseline average. baseline parameter must be provided.
.. versionchanged:: 1.8
Added the ``"constant"`` mode.
stim_channel : str | None
Stim channel to use.
%(picks_all_data)s
Expand All @@ -76,9 +97,22 @@ def fix_stim_artifact(
inst : instance of Raw or Evoked or Epochs
Instance with modified data.
"""
_check_option("mode", mode, ["linear", "window"])
_check_option("mode", mode, ["linear", "window", "constant"])
s_start = int(np.ceil(inst.info["sfreq"] * tmin))
s_end = int(np.ceil(inst.info["sfreq"] * tmax))
if mode == "constant":
_validate_type(
baseline, (tuple, list), "baseline", extra="when mode='constant'"
)
_check_option("len(baseline)", len(baseline), [2])
for bi, b in enumerate(baseline):
_validate_type(
b, "numeric", f"baseline[{bi}]", extra="when mode='constant'"
)
b_start = int(np.ceil(inst.info["sfreq"] * baseline[0]))
b_end = int(np.ceil(inst.info["sfreq"] * baseline[1]))
else:
b_start = b_end = np.nan
if (mode == "window") and (s_end - s_start) < 4:
raise ValueError(
'Time range is too short. Use a larger interval or set mode to "linear".'
Expand All @@ -104,7 +138,11 @@ def fix_stim_artifact(
for event_idx in event_start:
first_samp = int(event_idx) - inst.first_samp + s_start
last_samp = int(event_idx) - inst.first_samp + s_end
_fix_artifact(data, window, picks, first_samp, last_samp, mode)
base_t1 = int(event_idx) - inst.first_samp + b_start
base_t2 = int(event_idx) - inst.first_samp + b_end
_fix_artifact(
data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
)
elif isinstance(inst, BaseEpochs):
if inst.reject is not None:
raise RuntimeError(
Expand All @@ -114,14 +152,23 @@ def fix_stim_artifact(
first_samp = s_start - e_start
last_samp = s_end - e_start
data = inst._data
base_t1 = b_start - e_start
base_t2 = b_end - e_start
for epoch in data:
_fix_artifact(epoch, window, picks, first_samp, last_samp, mode)
_fix_artifact(
epoch, window, picks, first_samp, last_samp, base_t1, base_t2, mode
)

elif isinstance(inst, Evoked):
first_samp = s_start - inst.first
last_samp = s_end - inst.first
data = inst.data
_fix_artifact(data, window, picks, first_samp, last_samp, mode)
base_t1 = b_start - inst.first
base_t2 = b_end - inst.first

_fix_artifact(
data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
)

else:
raise TypeError(f"Not a Raw or Epochs or Evoked (got {type(inst)}).")
Expand Down
35 changes: 35 additions & 0 deletions mne/preprocessing/tests/test_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ def test_fix_stim_artifact():
data_from_epochs_fix = epochs.get_data(copy=False)[:, :, tmin_samp:tmax_samp]
assert not np.all(data_from_epochs_fix != 0)

baseline = (-0.1, -0.05)
epochs = fix_stim_artifact(
epochs, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant"
)
b_start = int(np.ceil(epochs.info["sfreq"] * baseline[0]))
b_end = int(np.ceil(epochs.info["sfreq"] * baseline[1]))
base_t1 = b_start - e_start
base_t2 = b_end - e_start
baseline_mean = epochs.get_data()[:, :, base_t1:base_t2].mean(axis=2)[0][0]
data = epochs.get_data()[:, :, tmin_samp:tmax_samp]
assert data[0][0][0] == baseline_mean

# use window before stimulus in raw
event_idx = np.where(events[:, 2] == 1)[0][0]
tmin, tmax = -0.045, -0.015
Expand All @@ -81,8 +93,22 @@ def test_fix_stim_artifact():
raw, events, event_id=1, tmin=tmin, tmax=tmax, mode="window"
)
data, times = raw[:, (tidx + tmin_samp) : (tidx + tmax_samp)]

assert np.all(data) == 0.0

raw = fix_stim_artifact(
raw,
events,
event_id=1,
tmin=tmin,
tmax=tmax,
baseline=baseline,
mode="constant",
)
data, times = raw[:, (tidx + tmin_samp) : (tidx + tmax_samp)]
baseline_mean, _ = raw[:, (tidx + b_start) : (tidx + b_end)]
assert baseline_mean.mean(axis=1)[0] == data[0][0]

# get epochs from raw with fixed data
tmin, tmax, event_id = -0.2, 0.5, 1
epochs = Epochs(
Expand Down Expand Up @@ -117,3 +143,12 @@ def test_fix_stim_artifact():
evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, mode="window")
data = evoked.data[:, tmin_samp:tmax_samp]
assert np.all(data) == 0.0

evoked = fix_stim_artifact(
evoked, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant"
)
base_t1 = int(baseline[0] * evoked.info["sfreq"]) - evoked.first
base_t2 = int(baseline[1] * evoked.info["sfreq"]) - evoked.first
data = evoked.data[:, tmin_samp:tmax_samp]
baseline_mean = evoked.data[:, base_t1:base_t2].mean(axis=1)[0]
assert data[0][0] == baseline_mean
10 changes: 8 additions & 2 deletions tools/install_pre_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@ echo "PyQt6 and scientific-python-nightly-wheels dependencies"
python -m pip install $STD_ARGS pip setuptools packaging \
threadpoolctl cycler fonttools kiwisolver pyparsing pillow python-dateutil \
patsy pytz tzdata nibabel tqdm trx-python joblib numexpr "$QT_BINDING" \
py-cpuinfo blosc2
py-cpuinfo blosc2 hatchling
echo "NumPy/SciPy/pandas etc."
python -m pip uninstall -yq numpy
python -m pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 \
--index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \
"numpy>=2.1.0.dev0" "scikit-learn>=1.6.dev0" "scipy>=1.15.0.dev0" \
"statsmodels>=0.15.0.dev0" "pandas>=3.0.0.dev0" "matplotlib>=3.10.0.dev0" \
"pandas>=3.0.0.dev0" "matplotlib>=3.10.0.dev0" \
"h5py>=3.12.1" "dipy>=1.10.0.dev0" "pyarrow>=19.0.0.dev0" "tables>=3.10.2.dev0"

# statsmodels requires formulaic@main so we need to use --extra-index-url
echo "statsmodels"
python -m pip install $STD_ARGS --only-binary ":all:" \
--extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \
"statsmodels>=0.15.0.dev0"

# No Numba because it forces an old NumPy version

echo "pymatreader"
Expand Down

0 comments on commit 610ec2a

Please sign in to comment.