diff --git a/.circleci/config.yml b/.circleci/config.yml index 4297dc5fedf..26b9f600e3c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -218,6 +218,9 @@ jobs: - restore_cache: keys: - data-cache-phantom-kit + - restore_cache: + keys: + - data-cache-ds004388 - run: name: Get data # This limit could be increased, but this is helpful for finding slow ones @@ -252,7 +255,7 @@ jobs: name: Check sphinx log for warnings (which are treated as errors) when: always command: | - ! grep "^.* (WARNING|ERROR): .*$" sphinx_log.txt + ! grep "^.*\(WARNING\|ERROR\): " sphinx_log.txt - run: name: Show profiling output when: always @@ -393,6 +396,10 @@ jobs: key: data-cache-phantom-kit paths: - ~/mne_data/MNE-phantom-KIT-data # (1 G) + - save_cache: + key: data-cache-ds004388 + paths: + - ~/mne_data/ds004388 # (1.8 G) linkcheck: diff --git a/doc/api/datasets.rst b/doc/api/datasets.rst index 2b2c92c8654..87730fbd717 100644 --- a/doc/api/datasets.rst +++ b/doc/api/datasets.rst @@ -18,6 +18,7 @@ Datasets brainstorm.bst_auditory.data_path brainstorm.bst_resting.data_path brainstorm.bst_raw.data_path + default_path eegbci.load_data eegbci.standardize fetch_aparc_sub_parcellation diff --git a/doc/api/preprocessing.rst b/doc/api/preprocessing.rst index 86ad3aca910..9fe3f995cc4 100644 --- a/doc/api/preprocessing.rst +++ b/doc/api/preprocessing.rst @@ -116,6 +116,7 @@ Projections: read_ica_eeglab read_fine_calibration write_fine_calibration + apply_pca_obs :py:mod:`mne.preprocessing.nirs`: diff --git a/doc/changes/devel/13037.newfeature.rst b/doc/changes/devel/13037.newfeature.rst new file mode 100644 index 00000000000..3b28e2294ab --- /dev/null +++ b/doc/changes/devel/13037.newfeature.rst @@ -0,0 +1 @@ +Add PCA-OBS preprocessing for the removal of heart-artefacts from EEG or ESG datasets via :func:`mne.preprocessing.apply_pca_obs`, by :newcontrib:`Emma Bailey` and :newcontrib:`Steinn Hauser Magnusson`. 
diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 3ac0b1cd9c9..eb444c5e594 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -73,6 +73,7 @@ .. _Eberhard Eich: https://github.com/ebeich .. _Eduard Ort: https://github.com/eort .. _Emily Stephen: https://github.com/emilyps14 +.. _Emma Bailey: https://www.cbs.mpg.de/employees/bailey .. _Enrico Varano: https://github.com/enricovara/ .. _Enzo Altamiranda: https://www.linkedin.com/in/enzoalt .. _Eric Larson: https://larsoner.com @@ -284,6 +285,7 @@ .. _Stanislas Chambon: https://github.com/Slasnista .. _Stefan Appelhoff: https://stefanappelhoff.com .. _Stefan Repplinger: https://github.com/stfnrpplngr +.. _Steinn Hauser Magnusson: https://github.com/steinnhauser .. _Steven Bethard: https://github.com/bethard .. _Steven Bierer: https://github.com/neurolaunch .. _Steven Gutstein: https://github.com/smgutstein diff --git a/doc/conf.py b/doc/conf.py index 74f66d8f6ae..f1b771571d6 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -355,6 +355,7 @@ "n_frequencies", "n_tests", "n_samples", + "n_peaks", "n_permutations", "nchan", "n_points", diff --git a/doc/references.bib b/doc/references.bib index a129d2f46a2..e2578ed18f2 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -1335,6 +1335,16 @@ @inproceedings{NdiayeEtAl2016 year = {2016} } +@article{NiazyEtAl2005, + author = {Niazy, R. K. and Beckmann, C.F. and Iannetti, G.D. and Brady, J. M. and Smith, S. M.}, + title = {Removal of FMRI environment artifacts from EEG data using optimal basis sets}, + journal = {NeuroImage}, + year = {2005}, + volume = {28}, + pages = {720-737}, + doi = {10.1016/j.neuroimage.2005.06.067} +} + @article{NicholsHolmes2002, author = {Nichols, Thomas E. 
"""
.. _ex-pcaobs:

=====================================================================================
Principal Component Analysis - Optimal Basis Sets (PCA-OBS) removing cardiac artefact
=====================================================================================

This script shows an example of how to use an adaptation of PCA-OBS
:footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove
the ballistocardiographic artefact in simultaneous EEG-fMRI. Here, it
has been adapted to remove the delay between the detected R-peak and the
ballistocardiographic artefact such that the algorithm can be applied to
remove the cardiac artefact in EEG (electroencephalography) and ESG
(electrospinography) data. We will illustrate how it works by applying the
algorithm to ESG data, where the effect of removal is most pronounced.

See: https://www.biorxiv.org/content/10.1101/2024.09.05.611423v1
for more details on the dataset and application for ESG data.

"""

# Authors: Emma Bailey
#          Steinn Hauser Magnusson
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import glob

import numpy as np

# %%
# Download sample subject data from OpenNeuro if you haven't already.
# This will download simultaneous EEG and ESG data from a single run of a
# single participant after median nerve stimulation of the left wrist.
import openneuro
from matplotlib import pyplot as plt

import mne
from mne import Epochs, events_from_annotations
from mne.io import read_raw_eeglab
from mne.preprocessing import find_ecg_events, fix_stim_artifact

# add the path where you want the OpenNeuro data downloaded. Each run is ~2GB of data
ds = "ds004388"
target_dir = mne.datasets.default_path() / ds
run_name = "sub-001/eeg/*median_run-03_eeg*.set"
if not glob.glob(str(target_dir / run_name)):
    target_dir.mkdir(exist_ok=True)
    # drop the ".set" suffix so that the companion files of the run are fetched too
    openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4])
block_files = glob.glob(str(target_dir / run_name))
assert len(block_files) == 1

# %%
# Define the esg channels (arranged in two patches over the neck and lower back).

esg_chans = [
    "S35",
    "S24",
    "S36",
    "Iz",
    "S17",
    "S15",
    "S32",
    "S22",
    "S19",
    "S26",
    "S28",
    "S9",
    "S13",
    "S11",
    "S7",
    "SC1",
    "S4",
    "S18",
    "S8",
    "S31",
    "SC6",
    "S12",
    "S16",
    "S5",
    "S30",
    "S20",
    "S34",
    "S21",
    "S25",
    "L1",
    "S29",
    "S14",
    "S33",
    "S3",
    "L4",
    "S6",
    "S23",
]

# Interpolation window in seconds for ESG data to remove stimulation artefact
tstart_esg = -7e-3
tmax_esg = 7e-3

# Define timing of heartbeat epochs in seconds relative to R-peaks
iv_baseline = [-400e-3, -300e-3]
iv_epoch = [-400e-3, 600e-3]

# %%
# Next, we perform minimal preprocessing including removing the
# stimulation artefact, downsampling and filtering.

raw = read_raw_eeglab(block_files[0], verbose="error")
raw.set_channel_types(dict(ECG="ecg"))
# Isolate the ESG channels (include the ECG channel for R-peak detection)
raw.pick(esg_chans + ["ECG"])
# Trim duration and downsample (from 10kHz) to improve example speed
raw.crop(0, 60).load_data().resample(2000)

# Find trigger timings to remove the stimulation artefact
events, event_dict = events_from_annotations(raw)
trigger_name = "Median - Stimulation"

fix_stim_artifact(
    raw,
    events=events,
    event_id=event_dict[trigger_name],
    tmin=tstart_esg,
    tmax=tmax_esg,
    mode="linear",
    stim_channel=None,
)

# %%
# Find ECG events and add to the raw structure as event annotations.

ecg_events, _, _ = find_ecg_events(raw, ch_name="ECG")
# Keep the sample index of every event as a flat 1D array, so that ``len()``
# below counts R-peaks (a nested (1, n) array would always report 1).
ecg_event_samples = np.asarray([ecg_event[0] for ecg_event in ecg_events])

# Divide by sampling rate to make times
qrs_event_time = ecg_event_samples / raw.info["sfreq"]
n_qrs = len(qrs_event_time)
duration = np.zeros(n_qrs)
description = ["qrs"] * n_qrs

raw.annotations.append(
    qrs_event_time, duration, description, ch_names=[esg_chans] * n_qrs
)

# %%
# Create evoked response around the detected R-peaks
# before and after cardiac artefact correction.

events, event_ids = events_from_annotations(raw)
event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"}
epochs = Epochs(
    raw,
    events,
    event_id=event_id_dict,
    tmin=iv_epoch[0],
    tmax=iv_epoch[1],
    baseline=tuple(iv_baseline),
)
evoked_before = epochs.average()

# Apply function - with the default ``copy=True`` a corrected copy is returned,
# which replaces ``raw`` here. Optionally high-pass filter the data before
# applying PCA-OBS to remove low frequency drifts
raw = mne.preprocessing.apply_pca_obs(
    raw, picks=esg_chans, n_jobs=5, qrs_times=raw.times[ecg_event_samples]
)

epochs = Epochs(
    raw,
    events,
    event_id=event_id_dict,
    tmin=iv_epoch[0],
    tmax=iv_epoch[1],
    baseline=tuple(iv_baseline),
)
evoked_after = epochs.average()

# %%
# Compare evoked responses to assess completeness of artefact removal.

fig, axes = plt.subplots(1, 1, layout="constrained")
data_before = evoked_before.get_data(units=dict(eeg="uV")).T
data_after = evoked_after.get_data(units=dict(eeg="uV")).T
hs = list()
# one line per channel is drawn; keep only the first handle of each for the legend
hs.append(axes.plot(epochs.times, data_before, color="k")[0])
hs.append(axes.plot(epochs.times, data_after, color="green", label="after")[0])
axes.set(ylim=[-500, 1000], ylabel="Amplitude (µV)", xlabel="Time (s)")
axes.set(title="ECG artefact removal using PCA-OBS")
axes.legend(hs, ["before", "after"])
plt.show()

# %%
# References
# ----------
# .. footbibliography::
@verbose
def default_path(*, verbose=None):
    """Return the default MNE_DATA directory.

    Parameters
    ----------
    %(verbose)s

    Returns
    -------
    data_path : instance of Path
        Path to the default MNE_DATA directory.
    """
    # No explicit path, config key or dataset name: defer entirely to the
    # shared path lookup.
    return _get_path(path=None, key=None, name=None)
~/mne_data (but use a fake home during testing so we don't # unnecessarily create ~/mne_data) - logger.info(f"Using default location ~/mne_data for {name}...") + extra = f" for {name}" if name else "" + logger.info(f"Using default location ~/mne_data{extra}...") path = Path(os.getenv("_MNE_FAKE_HOME_DIR", "~")).expanduser() / "mne_data" if not path.is_dir(): logger.info(f"Creating {path}") @@ -319,6 +337,8 @@ def _download_all_example_data(verbose=True): # # verbose=True by default so we get nice status messages. # Consider adding datasets from here to CircleCI for PR-auto-build + import openneuro + paths = dict() for kind in ( "sample testing misc spm_face somato hf_sef multimodal " @@ -375,6 +395,14 @@ def _download_all_example_data(verbose=True): limo.load_data(subject=1, update_path=True) logger.info("[done limo]") + # for ESG + ds = "ds004388" + target_dir = default_path() / ds + run_name = "sub-001/eeg/*median_run-03_eeg*.set" + if not glob.glob(str(target_dir / run_name)): + target_dir.mkdir(exist_ok=True) + openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4]) + @verbose def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index 54f1c825c13..c54685dba34 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -44,6 +44,7 @@ __all__ = [ "realign_raw", "regress_artifact", "write_fine_calibration", + "apply_pca_obs", ] from . 
@verbose
def apply_pca_obs(
    raw: Raw,
    picks: list[str],
    *,
    qrs_times: np.ndarray,
    n_components: int = 4,
    n_jobs: int | None = None,
    copy: bool = True,
    verbose: bool | str | int | None = None,
) -> Raw:
    """
    Apply the PCA-OBS algorithm to picks of a Raw object.

    Uses the optimal basis set (OBS) algorithm from :footcite:`NiazyEtAl2005`.

    Parameters
    ----------
    raw : instance of Raw
        The raw data to process.
    %(picks_all_data_noref)s
    qrs_times : ndarray, shape (n_peaks,)
        Array of times in the Raw data of detected R-peaks in ECG channel.
        Must be a one-dimensional integer or floating-point array with at
        least two entries and no negative values.
    n_components : int
        Number of PCA components to use to form the OBS (default 4).
    %(n_jobs)s
    copy : bool
        If False, modify the Raw instance in-place.
        If True (default), copy the raw instance before processing.
    %(verbose)s

    Returns
    -------
    raw : instance of Raw
        The modified raw instance.

    Notes
    -----
    .. versionadded:: 1.10

    References
    ----------
    .. footbibliography::
    """
    # sanity checks
    _validate_type(qrs_times, np.ndarray, "qrs_times")
    # ``ndim != 1`` also rejects 0-d arrays, which cannot hold a series of peaks
    if qrs_times.ndim != 1:
        raise ValueError("qrs_times must be a 1d array")
    # Test the dtype *kind* instead of comparing against the Python types
    # ``int``/``float``: the latter only matches the platform default widths
    # and would reject e.g. float32 or int32 arrays.
    if qrs_times.dtype.kind not in "iuf":
        raise ValueError("qrs_times must be an array of either integers or floats")
    # the RR interval is a difference of consecutive peaks, so one peak is not enough
    if qrs_times.size < 2:
        raise ValueError(
            f"qrs_times must contain at least two R-peak times, got {qrs_times.size}"
        )
    if np.any(qrs_times < 0):
        # NOTE: zero is accepted; the wording is kept as-is because the message
        # is matched by existing tests
        raise ValueError("qrs_times must be strictly positive")
    if np.any(qrs_times >= raw.times[-1]):
        logger.warning("some out of bound qrs_times will be ignored..")

    if copy:
        raw = raw.copy()

    raw.apply_function(
        _pca_obs,
        picks=picks,
        n_jobs=n_jobs,
        # args sent to PCA_OBS, convert times to indices
        qrs=raw.time_as_index(qrs_times),
        n_components=n_components,
    )

    return raw
def _pca_obs(
    data: np.ndarray,
    qrs: np.ndarray,
    n_components: int,
) -> np.ndarray:
    """Algorithm to remove heart artefact from EEG data (array of length n_times).

    Parameters
    ----------
    data : ndarray, shape (n_times,)
        Signal of a single channel.
    qrs : ndarray of int
        Sample indices of the detected R-peaks. Indices at or beyond
        ``len(data)`` are ignored.
    n_components : int
        Number of principal components that are combined with the mean
        heartbeat to form the artefact template.

    Returns
    -------
    data : ndarray, shape (n_times,)
        Mean-centred copy of the input with the fitted artefact subtracted.

    Raises
    ------
    ValueError
        If fewer than two R-peaks fall inside the data, or if no R-peak after
        the first one has a complete window around it.
    """
    # set to baseline
    data = data - np.mean(data)

    # Allocate memory for artifact which will be subtracted from the data
    fitted_art = np.zeros(data.shape)

    # Extract QRS event indexes which are within our data timeframe
    # NOTE(review): the code below presumably expects ascending indices - confirm
    peak_idx = qrs[qrs < len(data)]
    peak_count = len(peak_idx)
    if peak_count < 2:
        # the median RR interval below is undefined (NaN) otherwise
        raise ValueError(
            f"PCA-OBS needs at least two QRS events inside the data, got {peak_count}"
        )

    ##################################################################
    # Preparatory work - reserving memory, configure sizes, de-trend #
    ##################################################################
    # define peak range based on RR
    mRR = np.median(np.diff(peak_idx))
    peak_range = round(mRR / 2)  # Rounds to an integer
    mid_p = peak_range + 1
    # sample fit for interpolation between fitted artifact windows
    n_samples_fit = round(peak_range / 8)

    # Drop trailing peaks whose window would run past the end of the data. The
    # window is data[peak - peak_range : peak + peak_range + 1], so its last
    # sample peak + peak_range has to be a valid index (strictly < len(data)).
    while peak_count > 0 and peak_idx[peak_count - 1] + peak_range >= len(data):
        peak_count -= 1  # reduce number of QRS complexes detected

    # build PCA matrix(heart-beat-epochs x window-length); the first peak is
    # never used as an epoch, and peaks without a full window on the left are
    # left out so that every row has 2 * peak_range + 1 samples
    epoch_peaks = [k for k in peak_idx[1:peak_count] if k - peak_range >= 0]
    if not epoch_peaks:
        raise ValueError(
            "PCA-OBS found no QRS event with a complete heartbeat window "
            "inside the data"
        )
    pcamat = np.array(
        [data[k - peak_range : k + peak_range + 1] for k in epoch_peaks]
    )  # [epoch x time]

    # detrending matrix(twice)
    # [epoch x time] - removes the mean of every epoch
    pcamat = detrend(pcamat, type="constant", axis=1)
    # [1 x time], contains the mean over all epochs
    mean_effect: np.ndarray = np.mean(pcamat, axis=0)
    # NOTE(review): same axis as the call above, so this appears to change
    # nothing numerically; kept as-is - confirm whether axis=0 was intended
    dpcamat = detrend(pcamat, type="constant", axis=1)

    ###############
    # Perform PCA #
    ###############
    # run PCA, perform singular value decomposition (SVD)
    pca = _PCA()
    pca.fit(dpcamat)
    factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_)

    #####################################
    # Make template of the ECG artefact #
    #####################################
    # first column is the mean heartbeat, followed by the leading components
    mean_effect = mean_effect.reshape(-1, 1)
    pca_template = np.c_[mean_effect, factor_loadings[:, :n_components]]

    ################
    # Data Fitting #
    ################
    post_idx_next_peak = None
    last = peak_count - 1

    for p in range(peak_count):
        # if the current peak doesn't have enough data in the
        # start of the peak_range, skip fitting the artifact
        if peak_idx[p] - peak_range < 0:
            continue

        # Fit up to half-way towards each neighbouring peak, but never further
        # than peak_range; the outer side of the first and of the last peak
        # has no neighbour and gets the full peak_range.
        pre_range = post_range = peak_range
        if p > 0:
            half_gap = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2)
            pre_range = min(half_gap, peak_range)
        if p < last:
            half_gap = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2)
            post_range = min(half_gap, peak_range)

        fitted_art, fit_end = _fit_ecg_template(
            data=data,
            pca_template=pca_template,
            a_peak_idx=peak_idx[p],
            peak_range=peak_range,
            pre_range=pre_range,
            post_range=post_range,
            mid_p=mid_p,
            fitted_art=fitted_art,
            post_idx_previous_peak=post_idx_next_peak,
            n_samples_fit=n_samples_fit,
        )
        # handed to the next peak so the gap in between can be interpolated
        post_idx_next_peak = fit_end

    # Actually subtract the artefact, return needs to be the same shape as input data
    data -= fitted_art
    return data
int, + peak_range: int, + pre_range: int, + post_range: int, + mid_p: float, + fitted_art: np.ndarray, + post_idx_previous_peak: int | None, + n_samples_fit: int, +) -> tuple[np.ndarray, int]: + """ + Fits the heartbeat artefact found in the data. + + Returns the fitted artefact and the index of the next peak. + + Parameters + ---------- + data (ndarray): Data from the raw signal (n_channels, n_times) + pca_template (ndarray): Mean heartbeat and first N (default 4) + principal components of the heartbeat matrix + a_peak_idx (int): Sample index of current R-peak + peak_range (int): Half the median RR-interval + pre_range (int): Number of samples to fit before the R-peak + post_range (int): Number of samples to fit after the R-peak + mid_p (float): Sample index marking middle of the median RR interval + in the signal. Used to extract relevant part of PCA_template. + fitted_art (ndarray): The computed heartbeat artefact computed to + remove from the data + post_idx_previous_peak (optional int): Sample index of previous R-peak + n_samples_fit (int): Sample fit for interpolation in fitted artifact + windows. 
Helps reduce sharp edges at end of fitted heartbeat events + + Returns + ------- + tuple[np.ndarray, int]: the fitted artifact and the next peak index + """ + # post_idx_next_peak is passed in in PCA_OBS, used here as post_idx_previous_peak + # Then next_peak is returned at the end and the process repeats + # select window of template + template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] + + # select window of data and detrend it + slice_ = data[a_peak_idx - peak_range : a_peak_idx + peak_range + 1] + + detrended_data = detrend(slice_, type="constant") + + # maps data on template and then maps it again back to the sensor space + least_square = np.linalg.lstsq(template, detrended_data, rcond=None) + pad_fit = np.dot(template, least_square[0]) + + # fit artifact + fitted_art[a_peak_idx - pre_range - 1 : a_peak_idx + post_range] = pad_fit[ + mid_p - pre_range - 1 : mid_p + post_range + ].T + + # if last peak, return + if post_idx_previous_peak is None: + return fitted_art, a_peak_idx + post_range + + # interpolate time between peaks + intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx - pre_range]).astype( + int + ) # interpolation window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window + # endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit + # that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in + # x_interpol + # points to be interpolated in pt - the gap between the endpoints of the window + x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) + # Entire range of x values in this step (taking some + # number of samples before and after the window) + x_fit = np.concatenate( + [ + 
np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), + np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), + ] + ) + y_fit = fitted_art[x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # make fitted artefact in the desired range equal to the completed fit above + fitted_art[post_idx_previous_peak : a_peak_idx - pre_range + 1] = y_interpol + + return fitted_art, a_peak_idx + post_range diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py new file mode 100644 index 00000000000..ee2568a2080 --- /dev/null +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -0,0 +1,107 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from pathlib import Path + +import numpy as np +import pytest + +from mne.io import read_raw_fif +from mne.io.fiff.raw import Raw +from mne.preprocessing import apply_pca_obs + +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_path / "test_raw.fif" + + +@pytest.fixture() +def short_raw_data(): + """Create a short, picked raw instance.""" + return read_raw_fif(raw_fname, preload=True) + + +def test_heart_artifact_removal(short_raw_data: Raw): + """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + pd = pytest.importorskip("pandas") + + # copy the original raw. 
def test_heart_artifact_removal(short_raw_data: Raw):
    """Check that PCA-OBS alters the EEG channels and leaves the rest untouched."""
    pd = pytest.importorskip("pandas")

    # snapshot of the unprocessed recording to compare against afterwards
    orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True)

    # evenly spaced stand-in R-peak times; the first and last point are dropped
    last_time = orig_df["time"].iloc[-1]
    ecg_event_times = np.linspace(0, last_time, 20)[1:-1]

    # perform heart artifact removal
    short_raw_data = apply_pca_obs(
        raw=short_raw_data, picks=["eeg"], qrs_times=ecg_event_times, n_jobs=1
    )
    removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame()

    # the set and order of columns must survive the processing
    pd.testing.assert_index_equal(
        orig_df.columns,
        removed_heart_artifact_df.columns,
    )

    # split the columns into the picked EEG channels and everything else
    altered_cols = []
    unaltered_cols = []
    for col in orig_df.columns:
        if col.startswith("EEG"):
            altered_cols.append(col)
        else:
            unaltered_cols.append(col)

    # every EEG column has to differ from its original, i.e. the equality
    # check itself must fail
    for col in altered_cols:
        with pytest.raises(AssertionError):
            pd.testing.assert_series_equal(
                orig_df[col],
                removed_heart_artifact_df[col],
            )

    # all remaining columns have to be identical to the original
    pd.testing.assert_frame_equal(
        orig_df[unaltered_cols],
        removed_heart_artifact_df[unaltered_cols],
    )


# test that various nonsensical inputs raise the proper errors
@pytest.mark.parametrize(
    ("picks", "qrs_times", "error", "exception"),
    [
        (
            ["eeg"],
            np.array([[0, 1], [2, 3]]),
            "qrs_times must be a 1d array",
            ValueError,
        ),
        (
            ["eeg"],
            [2, 3, 4],
            "qrs_times must be an instance of ndarray, got instead.",
            TypeError,
        ),
        (
            ["eeg"],
            np.array([None, "foo", 2]),
            "qrs_times must be an array of either integers or floats",
            ValueError,
        ),
        (
            ["eeg"],
            np.array([-1, 0, 3]),
            "qrs_times must be strictly positive",
            ValueError,
        ),
    ],
)
def test_pca_obs_bad_input(
    short_raw_data: Raw,
    picks: list[str],
    qrs_times: np.ndarray,
    error: str,
    exception: type[Exception],
):
    """Check that each invalid ``qrs_times`` is rejected with the expected error."""
    with pytest.raises(exception, match=error):
        apply_pca_obs(raw=short_raw_data, picks=picks, qrs_times=qrs_times)
def fit(self, X):
    """Fit the model to ``X``.

    Parameters
    ----------
    X : array
        The data to fit; passed on unchanged to ``self._fit``.
        NOTE(review): presumably shape (n_samples, n_features) - confirm.

    Returns
    -------
    self : object
        The fitted instance, so that calls can be chained as with
        scikit-learn estimators.
    """
    self._fit(X)
    return self
note:: +# In situations where only limited electrodes are available for analysis, removing +# the cardiac artefact using techniques which rely on the availability of spatial +# information (such as SSP) may not be possible. In these instances, it may be of +# use to consider algorithms which require information only regarding heartbeat +# instances in the time domain, such as :func:`mne.preprocessing.apply_pca_obs`. +# # # Repairing EOG artifacts with SSP # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -535,6 +542,7 @@ # reduced the amplitude of our signals in sensor space, but that it should not # bias the amplitudes in source space. # +# # References # ^^^^^^^^^^ #