From 4031ff337c90105b5c4ccbe1bcb9eb4b8faa6a72 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Mon, 16 Dec 2024 16:29:11 +0000 Subject: [PATCH] [GSOC] Add surrogate data generation for significant connectivity estimation (#223) Co-authored-by: Eric Larson --- doc/api.rst | 3 +- doc/references.bib | 21 ++ examples/surrogate_connectivity.py | 353 ++++++++++++++++++ mne_connectivity/__init__.py | 2 +- mne_connectivity/datasets/__init__.py | 1 + mne_connectivity/datasets/surrogate.py | 141 +++++++ .../datasets/tests/test_datasets.py | 316 ++++++++++++++++ mne_connectivity/tests/test_datasets.py | 169 --------- 8 files changed, 835 insertions(+), 171 deletions(-) create mode 100644 examples/surrogate_connectivity.py create mode 100644 mne_connectivity/datasets/surrogate.py create mode 100644 mne_connectivity/datasets/tests/test_datasets.py delete mode 100644 mne_connectivity/tests/test_datasets.py diff --git a/doc/api.rst b/doc/api.rst index 81c844187..1a84a4ad9 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -111,4 +111,5 @@ Dataset functions .. autosummary:: :toctree: generated/ - make_signals_in_freq_bands \ No newline at end of file + make_signals_in_freq_bands + make_surrogate_data \ No newline at end of file diff --git a/doc/references.bib b/doc/references.bib index 536bfad39..55cc7e338 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -61,6 +61,16 @@ @article{Dawson_2016 year = {2016} } +@article{DowdingHaufe2018, + title={Powerful statistical inference for nested data using sufficient summary statistics}, + author={Dowding, Irene and Haufe, Stefan}, + doi={10.3389/fnhum.2018.00103}, + journal={Frontiers in Human Neuroscience}, + volume={12}, + pages={103}, + year={2018} +} + @article{EwaldEtAl2012, author = {Ewald, Arne and Marzetti, Laura and Zappasodi, Filippo and Meinecke, Frank C. and Nolte, Guido}, doi = {10.1016/j.neuroimage.2011.11.084}, @@ -185,6 +195,17 @@ @book{OppenheimEtAl1999 year = {1999} } +@article{PellegriniEtAl2023, + title={Identifying good practices for detecting inter-regional linear functional connectivity from {EEG}}, + author={Pellegrini, Franziska and Delorme, Arnaud and Nikulin, Vadim and Haufe, Stefan}, + doi={10.1016/j.neuroimage.2023.120218}, + journal={NeuroImage}, + volume={277}, + pages={120218}, + year={2023}, + publisher={Elsevier} +} + @book{SekiharaNagarajan2008, author = {Sekihara, Kensuke and Nagarajan, Srikantan S.}, doi = {10.1007/978-3-540-79370-0}, diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py new file mode 100644 index 000000000..763fe1982 --- /dev/null +++ b/examples/surrogate_connectivity.py @@ -0,0 +1,353 @@ +""" +================================================================================== +Determine the significance of connectivity estimates against baseline connectivity +================================================================================== + +This example demonstrates how surrogate data can be generated to assess whether +connectivity estimates are significantly greater than baseline. +""" + +# Author: Thomas S. 
Binns +# License: BSD (3-clause) +# sphinx_gallery_thumbnail_number = 3 + +# %% + +import matplotlib.pyplot as plt +import mne +import numpy as np +from mne.datasets import somato + +from mne_connectivity import make_surrogate_data, spectral_connectivity_epochs + +######################################################################################## +# Background +# ---------- +# +# When performing connectivity analyses, we often want to know whether the results we +# observe reflect genuine interactions between signals. We can assess this by performing +# statistical tests between our connectivity estimates and a 'baseline' level of +# connectivity. However, due to factors such as background noise and sample +# size-dependent biases (see e.g. :footcite:`VinckEtAl2010`), it is often not +# appropriate to treat 0 as this baseline. Therefore, we need a way to estimate the +# baseline level of connectivity. +# +# One approach is to manipulate the original data in such a way that the covariance +# structure is destroyed, creating surrogate data. Connectivity estimates from the +# original and surrogate data can then be compared to determine whether the original +# data contains significant interactions. +# +# Such surrogate data can be easily generated in MNE using the +# :func:`~mne_connectivity.make_surrogate_data` function, which shuffles epoched data +# independently across channels :footcite:`PellegriniEtAl2023` (see the Notes section of +# the function for more information). In this example, we will demonstrate how surrogate +# data can be created, and how you can use this to assess the statistical significance +# of your connectivity estimates. + +######################################################################################## +# Loading the data +# ---------------- +# +# We start by loading from the :ref:`somato-dataset` dataset, MEG data showing +# event-related activity in response to somatosensory stimuli. We construct epochs +# around these events in the time window [-1.5, 1.0] seconds. + +# %% + +# Load data +data_path = somato.data_path() +raw_fname = data_path / "sub-01" / "meg" / "sub-01_task-somato_meg.fif" +raw = mne.io.read_raw_fif(raw_fname) +events = mne.find_events(raw, stim_channel="STI 014") + +# Pre-processing +raw.pick("grad").load_data() # focus on gradiometers +raw.filter(1, 35) +raw, events = raw.resample(sfreq=100, events=events) # reduce compute time + +# Construct epochs around events +epochs = mne.Epochs( + raw, events, event_id=1, tmin=-1.5, tmax=1.0, baseline=(-0.5, 0), preload=True +) +epochs = epochs[:30] # select a subset of epochs to speed up computation + +######################################################################################## +# Assessing connectivity in non-evoked data +# ----------------------------------------- +# +# We will first demonstrate how connectivity can be assessed from non-evoked data. In +# this example, we use data from the pre-trial period of [-1.5, -0.5] seconds. We +# compute Fourier coefficients of the data using the :meth:`~mne.Epochs.compute_psd` +# method with ``output="complex"`` (note that this requires ``mne >= 1.8``). +# +# Next, we pass these coefficients to +# :func:`~mne_connectivity.spectral_connectivity_epochs` to compute connectivity using +# the imaginary part of coherency (``imcoh``). Our indices specify that connectivity +# should be computed between all pairs of channels. 
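+
+########################################################################################
+# .. note:: As a purely illustrative aside (not part of the analysis below), the next
+#    cell sketches the epoch shuffling performed by
+#    :func:`~mne_connectivity.make_surrogate_data` (see its Notes section): epochs are
+#    permuted independently for each channel, destroying the across-channel covariance
+#    while leaving each channel's own spectral content untouched. The toy array and its
+#    shape here are hypothetical and are not used in the analysis that follows.
+
+# %%
+
+# Illustrate the shuffling on a hypothetical (n_epochs, n_channels, n_freqs) array
+rng = np.random.default_rng(44)
+toy = rng.random((5, 3, 10))
+shuffled = np.zeros_like(toy)
+for ch_i in range(toy.shape[1]):
+    # permute the epoch order independently for this channel
+    shuffled[:, ch_i] = rng.permutation(toy[:, ch_i], axis=0)
+# each channel keeps the same values; only their epoch order differs
+print(np.allclose(np.sort(toy, axis=0), np.sort(shuffled, axis=0)))  # True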
+ +# %% + +# Compute Fourier coefficients for pre-trial data +fmin, fmax = 3, 23 +pretrial_coeffs = epochs.compute_psd( + fmin=fmin, fmax=fmax, tmin=None, tmax=-0.5, output="complex" +) +freqs = pretrial_coeffs.freqs + +# Compute connectivity for pre-trial data +indices = np.tril_indices(epochs.info["nchan"], k=-1) # all-to-all connectivity +pretrial_con = spectral_connectivity_epochs( + pretrial_coeffs, method="imcoh", indices=indices +) + +######################################################################################## +# Next, we generate the surrogate data by passing the Fourier coefficients into the +# :func:`~mne_connectivity.make_surrogate_data` function. To get a reliable estimate of +# the baseline connectivity, we perform this shuffling procedure +# :math:`\text{n}_{\text{shuffle}}` times, producing :math:`\text{n}_{\text{shuffle}}` +# surrogate datasets. We can then iterate over these shuffles and compute the +# connectivity for each one. + +# %% + +# Generate surrogate data +n_shuffles = 100 # recommended is >= 1,000; limited here to reduce compute time +pretrial_surrogates = make_surrogate_data( + pretrial_coeffs, n_shuffles=n_shuffles, rng_seed=44 +) + +# Compute connectivity for surrogate data +surrogate_con = [] +for shuffle_i, surrogate in enumerate(pretrial_surrogates, 1): + print(f"Computing connectivity for shuffle {shuffle_i} of {n_shuffles}") + surrogate_con.append( + spectral_connectivity_epochs( + surrogate, method="imcoh", indices=indices, verbose=False + ) + ) + +######################################################################################## +# We can plot the all-to-all connectivity of the pre-trial data against the surrogate +# data, averaged over all shuffles. This shows a strong degree of coupling in the alpha +# band (~8-12 Hz), with weaker coupling in the lower range of the beta band (~13-20 Hz). +# A simple visual inspection shows that connectivity in the alpha and beta bands are +# above the baseline level of connectivity estimated from the surrogate data. However, +# we need to confirm this statistically. + +# %% + +# Plot pre-trial vs. surrogate connectivity +fig, ax = plt.subplots(1, 1) +ax.plot( + freqs, + np.abs([surrogate.get_data() for surrogate in surrogate_con]).mean(axis=(0, 1)), + linestyle="--", + label="Surrogate", +) +ax.plot(freqs, np.abs(pretrial_con.get_data()).mean(axis=0), label="Original") +ax.set_xlabel("Frequency (Hz)") +ax.set_ylabel("Connectivity (A.U.)") +ax.set_title("All-to-all connectivity | Pre-trial ") +ax.legend() + +######################################################################################## +# Assessing the statistical significance of our connectivity estimates can be done with +# the following simple procedure :footcite:`PellegriniEtAl2023` +# +# :math:`p=\LARGE{\frac{\Sigma_{s=1}^Sc_s}{S}}` , +# +# :math:`c_s=\{1\text{ if }\text{Con}\leq\text{Con}_{\text{s}}\text{ },\text{ }0 +# \text{ if otherwise }` , +# +# where: :math:`p` is our p-value; :math:`s` is a given shuffle iteration of :math:`S` +# total shuffles; and :math:`c` is a binary indicator of whether the true connectivity, +# :math:`\text{Con}`, is greater than the surrogate connectivity, +# :math:`\text{Con}_{\text{s}}`, for a given shuffle. +# +# Note that for connectivity methods which produce negative scores (e.g., imaginary part +# of coherency, time-reversed Granger causality, etc...), you should take the absolute +# values before testing. 
Similar adjustments should be made for methods that produce +# scores centred around non-zero values (e.g., 0.5 for directed phase lag index). +# +# Below, we determine the statistical significance of connectivity in the lower beta +# band. We simplify this by averaging over all connections and corresponding frequency +# bins. We could of course also test the significance of each connection, each frequency +# bin, or other frequency bands such as the alpha band. Naturally, any tests involving +# multiple connections, frequencies, and/or times should be corrected for multiple +# comparisons. +# +# The test confirms our visual inspection, showing that connectivity in the lower beta +# band is significantly above the baseline level of connectivity at an alpha of 0.05, +# which we can take as evidence of genuine interactions in this frequency band. + +# %% + +# Find indices of lower beta frequencies +beta_freqs = np.where((freqs >= 13) & (freqs <= 20))[0] + +# Compute lower beta connectivity for pre-trial data (average connections and freqs) +beta_con_pretrial = np.abs(pretrial_con.get_data()[:, beta_freqs]).mean(axis=(0, 1)) + +# Compute lower beta connectivity for surrogate data (average connections and freqs) +beta_con_surrogate = np.abs( + [surrogate.get_data()[:, beta_freqs] for surrogate in surrogate_con] +).mean(axis=(1, 2)) + +# Compute p-value for pre-trial lower beta coupling +p_val = np.sum(beta_con_pretrial <= beta_con_surrogate) / n_shuffles +print(f"P = {p_val:.2f}") + +######################################################################################## +# Assessing connectivity in evoked data +# ------------------------------------- +# +# When generating surrogate data, it is important to distinguish non-evoked data (e.g., +# resting-state, pre/inter-trial data) from evoked data (where a stimulus is presented +# or an action performed at a set time during each epoch). Critically, evoked data +# contains a temporal structure that is consistent across epochs, and thus shuffling +# epochs across channels will fail to adequately disrupt the covariance structure. +# +# Any connectivity estimates will therefore overestimate the baseline connectivity in +# your data, increasing the likelihood of type II errors (see the Notes section of +# :func:`~mne_connectivity.make_surrogate_data` for more information, and see the +# section :ref:`inappropriate-surrogate-data` for a demonstration). +# +# **In cases where you want to assess connectivity in evoked data, you can use +# surrogates generated from non-evoked data (of the same subject).** Here we do just +# that, comparing connectivity estimates from the pre-trial surrogates to the evoked, +# post-stimulus response ([0, 1] second). +# +# Again, there is pronounced alpha coupling (stronger than in the pre-trial data) and +# weaker beta coupling, both of which appear to be above the baseline level of +# connectivity. + +# %% + +# Compute Fourier coefficients for post-stimulus data +poststim_coeffs = epochs.compute_psd( + fmin=fmin, fmax=fmax, tmin=0, tmax=None, output="complex" +) + +# Compute connectivity for post-stimulus data +poststim_con = spectral_connectivity_epochs( + poststim_coeffs, method="imcoh", indices=indices +) + +# Plot post-stimulus vs. 
(pre-trial) surrogate connectivity +fig, ax = plt.subplots(1, 1) +ax.plot( + freqs, + np.abs([surrogate.get_data() for surrogate in surrogate_con]).mean(axis=(0, 1)), + linestyle="--", + label="Surrogate", +) +ax.plot(freqs, np.abs(poststim_con.get_data()).mean(axis=0), label="Original") +ax.set_xlabel("Frequency (Hz)") +ax.set_ylabel("Connectivity (A.U.)") +ax.set_title("All-to-all connectivity | Post-stimulus") +ax.legend() + +######################################################################################## +# This is also confirmed by statistical testing, with connectivity in the lower beta +# band being significantly above the baseline level of connectivity. Thus, using +# surrogate connectivity estimates from non-evoked data provides a reliable baseline for +# assessing connectivity in evoked data. + +# %% + +# Compute lower beta connectivity for post-stimulus data (average connections and freqs) +beta_con_poststim = np.abs(poststim_con.get_data()[:, beta_freqs]).mean(axis=(0, 1)) + +# Compute p-value for post-stimulus lower beta coupling +p_val = np.sum(beta_con_poststim <= beta_con_surrogate) / n_shuffles +print(f"P = {p_val:.2f}") + +######################################################################################## +# .. _inappropriate-surrogate-data: +# +# Generating surrogate connectivity from inappropriate data +# --------------------------------------------------------- +# We discussed above how surrogates generated from evoked data risk overestimating the +# degree of baseline connectivity. We demonstrate this below by generating surrogates +# from the post-stimulus data. + +# %% + +# Generate surrogates from evoked data +poststim_surrogates = make_surrogate_data( + poststim_coeffs, n_shuffles=n_shuffles, rng_seed=44 +) + +# Compute connectivity for evoked surrogate data +bad_surrogate_con = [] +for shuffle_i, surrogate in enumerate(poststim_surrogates, 1): + print(f"Computing connectivity for shuffle {shuffle_i} of {n_shuffles}") + bad_surrogate_con.append( + spectral_connectivity_epochs( + surrogate, method="imcoh", indices=indices, verbose=False + ) + ) + +######################################################################################## +# Plotting the post-stimulus connectivity against the estimates from the non-evoked and +# evoked surrogate data, we see that the evoked surrogate data greatly overestimates the +# baseline connectivity in the alpha band. +# +# Although in this case the alpha connectivity was still far above the baseline from the +# evoked surrogates, this will not always be the case, and you can see how this risks +# false negative assessments that connectivity is not significantly different from +# baseline. + +# %% + +# Plot post-stimulus vs. 
evoked and non-evoked surrogate connectivity +fig, ax = plt.subplots(1, 1) +ax.plot( + freqs, + np.abs([surrogate.get_data() for surrogate in surrogate_con]).mean(axis=(0, 1)), + linestyle="--", + label="Surrogate (pre-stimulus)", +) +ax.plot( + freqs, + np.abs([surrogate.get_data() for surrogate in bad_surrogate_con]).mean(axis=(0, 1)), + color="C3", + linestyle="--", + label="Surrogate (post-stimulus)", +) +ax.plot( + freqs, np.abs(poststim_con.get_data()).mean(axis=0), color="C1", label="Original" +) +ax.set_xlabel("Frequency (Hz)") +ax.set_ylabel("Connectivity (A.U.)") +ax.set_title("All-to-all connectivity | Post-stimulus") +ax.legend() + +######################################################################################## +# Assessing connectivity on a group-level +# --------------------------------------- +# +# While our focus here has been on assessing the significance of connectivity on a +# single recording-level, we may also want to determine whether group-level connectivity +# estimates are significantly different from baseline. For this, we can generate +# surrogates and estimate connectivity alongside the original signals for each piece of +# data. +# +# There are multiple ways to assess the statistical significance. For example, we can +# compute p-values for each piece of data using the approach above and combine them for +# the nested data (e.g., across recordings, subjects, etc...) using Stouffer's method +# :footcite:`DowdingHaufe2018`. +# +# Alternatively, we could take the average of the surrogate connectivity estimates +# across all shuffles for each piece of data and compare them to the original +# connectivity estimates in a paired test. The :mod:`scipy.stats` and :mod:`mne.stats` +# modules have many such tools for testing this, e.g., :func:`scipy.stats.ttest_1samp`, +# :func:`mne.stats.permutation_t_test`, etc... +# +# Altogether, surrogate connectivity estimates are a powerful tool for assessing the +# significance of connectivity estimates, both on a single recording- and group-level. + +######################################################################################## +# References +# ---------- +# .. footbibliography:: diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index 480194e8f..80c7fe4e7 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -23,7 +23,7 @@ SpectroTemporalConnectivity, TemporalConnectivity, ) -from .datasets import make_signals_in_freq_bands +from .datasets import make_signals_in_freq_bands, make_surrogate_data from .decoding import CoherencyDecomposition from .effective import phase_slope_index from .envelope import envelope_correlation, symmetric_orth diff --git a/mne_connectivity/datasets/__init__.py b/mne_connectivity/datasets/__init__.py index d5c8e2eb8..dc2296ee0 100644 --- a/mne_connectivity/datasets/__init__.py +++ b/mne_connectivity/datasets/__init__.py @@ -1 +1,2 @@ from .frequency import make_signals_in_freq_bands +from .surrogate import make_surrogate_data diff --git a/mne_connectivity/datasets/surrogate.py b/mne_connectivity/datasets/surrogate.py new file mode 100644 index 000000000..a482b5c60 --- /dev/null +++ b/mne_connectivity/datasets/surrogate.py @@ -0,0 +1,141 @@ +# Authors: Thomas S. 
Binns +# +# License: BSD (3-clause) + +import numpy as np +from mne.time_frequency import EpochsSpectrum, EpochsSpectrumArray +from mne.utils import _validate_type + + +def make_surrogate_data(data, n_shuffles=1000, rng_seed=None, return_generator=True): + """Create surrogate data for a null hypothesis of connectivity. + + Parameters + ---------- + data : ~mne.time_frequency.EpochsSpectrum | ~mne.time_frequency.EpochsSpectrumArray + The Fourier coefficients to create the null hypothesis surrogate data for. Can + be generated from :meth:`mne.Epochs.compute_psd` with ``output='complex'`` + (requires ``mne >= 1.8``). + n_shuffles : int (default 1000) + The number of surrogate datasets to create. + rng_seed : int | None (default None) + The seed to use for the random number generator. If `None`, no seed is + specified. + return_generator : bool (default True) + Whether or not to return the surrogate data as a :term:`generator` object + instead of a :class:`list`. This allows iterating over the surrogates without + having to keep them all in memory. + + Returns + ------- + surrogate_data : list of ~mne.time_frequency.EpochsSpectrum + The surrogate data for the null hypothesis with ``n_shuffles`` entries. Returned + as a :term:`generator` if ``return_generator=True``. + + Notes + ----- + Surrogate data is generated by randomly shuffling the order of epochs, independently + for each channel. This destroys the covariance of the data, such that connectivity + estimates should reflect the null hypothesis of no genuine connectivity between + signals (e.g., only interactions due to background noise) + :footcite:`PellegriniEtAl2023`. + + For the surrogate data to properly reflect a null hypothesis, the data which is + shuffled **must not** have a temporal structure that is consistent across epochs. + Examples of this data include evoked potentials, where a stimulus is presented or an + action performed at a set time during each epoch. Such data should not be used for + generating surrogates, as even after shuffling the epochs, it will still show a high + degree of residual connectivity between channels. As a result, connectivity + estimates from your surrogate data will capture genuine interactions, instead of the + desired background noise. Treating these estimates as a null hypothesis will + increase the likelihood of a type II (false negative) error, i.e., that there is no + significant connectivity in your data. + + Appropriate data for generating surrogates includes data from a resting state, + inter-trial period, or similar. Here, a strong temporal consistency across epochs is + not assumed, reducing the chances that connectivity information of interest is + captured in your surrogate connectivity estimates. + + In situations where you want to assess whether evoked data has significant + connectivity, you can generate your surrogate connectivity estimates from non-evoked + data (e.g., rest data, inter-trial data) and compare this to your true connectivity + estimates from the evoked data. + + Regardless of whether you are working with evoked or non-evoked data, **you should + always compare true and surrogate connectivity estimates from epochs of the same + duration**. This will ensure that spectral information is captured with the same + accuracy in both sets of connectivity estimates. 
Ideally, **you should also compare + true and surrogate connectivity estimates from the same number of epochs** to avoid + biases from noise (fewer epochs gives noisier estimates) or finite sample sizes + (e.g., in coherency, phase-locking value, etc... :footcite:`VinckEtAl2010`). + + .. versionadded:: 0.8 + + References + ---------- + .. footbibliography:: + """ + # Validate inputs + _validate_type( + data, + (EpochsSpectrum, EpochsSpectrumArray), + "data", + "mne.time_frequency.EpochsSpectrum or mne.time_frequency.EpochsSpectrumArray", + ) + if not np.iscomplexobj(data.get_data()): + raise TypeError("values in `data` must be complex-valued") + n_epochs, n_chans = data.get_data().shape[:2] + if n_epochs == 1: + raise ValueError("data must contain more than one epoch for shuffling") + if n_chans == 1: + raise ValueError("data must contain more than one channel for shuffling") + + _validate_type(n_shuffles, "int-like", "n_shuffles", "int") + if n_shuffles < 1: + raise ValueError("number of shuffles must be >= 1") + + _validate_type(return_generator, bool, "return_generator", "bool") + # rng_seed checked by NumPy later + + # Make surrogate data and package into EpochsSpectrum objects + surrogate_data = _shuffle_coefficients(data, n_shuffles, rng_seed) + if not return_generator: + surrogate_data = [shuffle for shuffle in surrogate_data] + + return surrogate_data + + +def _shuffle_coefficients(data, n_shuffles, rng_seed): + """Shuffle coefficients over epochs to create surrogate data. + + Surrogate data for each shuffle packaged into an EpochsSpectrum object, which are + together returned as a generator to minimise memory demand. + """ + # Extract data array and EpochsSpectrum information + data_arr = data.get_data() + state = data.__getstate__() + defaults = dict( + method=None, + fmin=None, + fmax=None, + tmin=None, + tmax=None, + picks=None, + exclude=(), + proj=None, + remove_dc=None, + n_jobs=None, + verbose=None, + ) + + # Make surrogate data + rng = np.random.default_rng(rng_seed) + for _ in range(n_shuffles): + # Shuffle epochs for each channel independently + surrogate_arr = np.zeros_like(data_arr, dtype=data_arr.dtype) + for chan_i in range(data_arr.shape[1]): + surrogate_arr[:, chan_i] = rng.permutation(data_arr[:, chan_i], axis=0) + + # Package surrogate data for this shuffle + state["data"] = surrogate_arr + yield EpochsSpectrum(state, **defaults) # return surrogate data as a generator diff --git a/mne_connectivity/datasets/tests/test_datasets.py b/mne_connectivity/datasets/tests/test_datasets.py new file mode 100644 index 000000000..70edccb54 --- /dev/null +++ b/mne_connectivity/datasets/tests/test_datasets.py @@ -0,0 +1,316 @@ +from collections.abc import Generator + +import numpy as np +import pytest +from mne import create_info +from mne.time_frequency import EpochsSpectrumArray + +from mne_connectivity import ( + make_signals_in_freq_bands, + make_surrogate_data, + seed_target_indices, + spectral_connectivity_epochs, +) + + +@pytest.mark.parametrize("n_seeds", [1, 3]) +@pytest.mark.parametrize("n_targets", [1, 3]) +@pytest.mark.parametrize("snr", [0.7, 0.4]) +@pytest.mark.parametrize("connection_delay", [0, 3, -3]) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) +def test_make_signals_in_freq_bands(n_seeds, n_targets, snr, connection_delay, mode): + """Test `make_signals_in_freq_bands` simulates connectivity properly.""" + # Case with no spurious correlations (avoids tests randomly failing) + rng_seed = 0 + + # Simulate data + freq_band = (5, 10) # 
fmin, fmax (Hz) + sfreq = 100 # Hz + trans_bandwidth = 1 # Hz + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_band, + n_epochs=30, + n_times=200, + sfreq=sfreq, + trans_bandwidth=trans_bandwidth, + snr=snr, + connection_delay=connection_delay, + rng_seed=rng_seed, + ) + + # Compute connectivity + methods = ["coh", "imcoh", "dpli"] + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + fmin = 3 + fmax = sfreq // 2 + if mode == "cwt_morlet": + cwt_params = {"cwt_freqs": np.arange(fmin, fmax), "cwt_n_cycles": 3.5} + else: + cwt_params = dict() + con = spectral_connectivity_epochs( + data, + method=methods, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + **cwt_params, + ) + freqs = np.array(con[0].freqs) + + # Define expected connectivity values + thresh_good = dict() + thresh_bad = dict() + # Coh + thresh_good["coh"] = (0.2, 0.9) + thresh_bad["coh"] = (0.0, 0.2) + # ImCoh + if connection_delay == 0: + thresh_good["imcoh"] = (0.0, 0.17) + thresh_bad["imcoh"] = (0.0, 0.17) + else: + thresh_good["imcoh"] = (0.17, 0.8) + thresh_bad["imcoh"] = (0.0, 0.17) + # DPLI + if connection_delay == 0: + thresh_good["dpli"] = (0.3, 0.6) + thresh_bad["dpli"] = (0.3, 0.6) + elif connection_delay > 0: + thresh_good["dpli"] = (0.5, 1) + thresh_bad["dpli"] = (0.3, 0.6) + else: + thresh_good["dpli"] = (0, 0.5) + thresh_bad["dpli"] = (0.3, 0.6) + + # Check connectivity values are acceptable + freqs_good = np.argwhere( + (freqs >= freq_band[0]) & (freqs <= freq_band[1]) + ).flatten() + freqs_bad = np.argwhere( + (freqs < freq_band[0] - trans_bandwidth * 2) + | (freqs > freq_band[1] + trans_bandwidth * 2) + ).flatten() + for method_name, method_con in zip(methods, con): + con_values = method_con.get_data() + if method_name == "imcoh": + con_values = np.abs(con_values) + # freq. band of interest + con_values_good = np.mean(con_values[:, freqs_good]) + assert ( + con_values_good >= thresh_good[method_name][0] + and con_values_good <= thresh_good[method_name][1] + ) + + # other freqs. + con_values_bad = np.mean(con_values[:, freqs_bad]) + assert ( + con_values_bad >= thresh_bad[method_name][0] + and con_values_bad <= thresh_bad[method_name][1] + ) + + +def test_make_signals_in_freq_bands_error_catch(): + """Test error catching for `make_signals_in_freq_bands`.""" + freq_band = (5, 10) + + # check bad n_seeds/targets caught + with pytest.raises( + ValueError, match="Number of seeds and targets must each be at least 1." + ): + make_signals_in_freq_bands(n_seeds=0, n_targets=1, freq_band=freq_band) + with pytest.raises( + ValueError, match="Number of seeds and targets must each be at least 1." 
+ ): + make_signals_in_freq_bands(n_seeds=1, n_targets=0, freq_band=freq_band) + + # check bad freq_band caught + with pytest.raises(TypeError, match="Frequency band must be a tuple."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=1) + with pytest.raises(ValueError, match="Frequency band must contain two numbers."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=(1, 2, 3)) + + # check bad n_times + with pytest.raises(ValueError, match="Number of timepoints must be at least 1."): + make_signals_in_freq_bands( + n_seeds=1, n_targets=1, freq_band=freq_band, n_times=0 + ) + + # check bad n_epochs + with pytest.raises(ValueError, match="Number of epochs must be at least 1."): + make_signals_in_freq_bands( + n_seeds=1, n_targets=1, freq_band=freq_band, n_epochs=0 + ) + + # check bad sfreq + with pytest.raises(ValueError, match="Sampling frequency must be > 0."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, sfreq=0) + + # check bad snr + with pytest.raises( + ValueError, match="Signal-to-noise ratio must be between 0 and 1." + ): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=-1) + with pytest.raises( + ValueError, match="Signal-to-noise ratio must be between 0 and 1." + ): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=2) + + # check bad connection_delay + with pytest.raises( + ValueError, + match="Connection delay must be less than the total number of timepoints.", + ): + make_signals_in_freq_bands( + n_seeds=1, + n_targets=1, + freq_band=freq_band, + n_epochs=1, + n_times=1, + connection_delay=1, + ) + + +@pytest.mark.parametrize(("snr", "should_be_significant"), ([0.3, True], [0.1, False])) +@pytest.mark.parametrize("mode", ["multitaper", "fourier"]) +def test_make_surrogate_data(snr, should_be_significant, mode): + """Test `make_surrogate_data` creates data for null hypothesis testing.""" + # Generate data + n_seeds = 2 + n_targets = 2 + freq_band = (10, 15) + n_epochs = 30 + sfreq = 100 + n_times = sfreq * 2 + n_shuffles = 1000 + rng_seed = 44 + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_band, + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, # using very high SNR seems to alter properties of data beyond fband + rng_seed=rng_seed, + ) + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + + # Compute Fourier coefficients and generate surrogates + spectrum = data.compute_psd( + method="welch" if mode == "fourier" else mode, output="complex" + ) + surrogate_spectrum = make_surrogate_data( + data=spectrum, n_shuffles=1000, rng_seed=rng_seed + ) + + # Compute connectivity + con = spectral_connectivity_epochs(data=spectrum, method="coh", indices=indices) + freqs = np.array(con.freqs) + connectivity = np.zeros((n_shuffles + 1, *con.shape)) + connectivity[0] = con.get_data() # first entry is original data + for shuffle_i, shuffle_data in enumerate(surrogate_spectrum): + connectivity[shuffle_i + 1] = spectral_connectivity_epochs( + data=shuffle_data, method="coh", indices=indices, verbose=False + ).get_data() + + # Determine if connectivity significant + alpha = 0.05 + con_freqs = (freqs >= freq_band[0]) & (freqs <= freq_band[1]) + noise_freqs = np.invert(con_freqs) + + pval_con_freqs = ( + np.sum( + np.mean(connectivity[0, :, con_freqs]) # aggregate cons and freqs + <= np.mean(connectivity[1:, :, con_freqs], axis=(1, 2)) # same aggr. 
here + ) + / n_shuffles + ) + + pval_noise_freqs = ( + np.sum( + np.mean(connectivity[0, :, noise_freqs]) + <= np.mean(connectivity[1:, :, noise_freqs], axis=(1, 2)) + ) + / n_shuffles + ) + + if should_be_significant: + assert pval_con_freqs < alpha, f"pval_con_freqs: {pval_con_freqs}" + else: + assert pval_con_freqs >= alpha, f"pval_con_freqs: {pval_con_freqs}" + + # Freqs where nothing simulated should never be significant + assert pval_noise_freqs > alpha, f"pval_noise_freqs: {pval_noise_freqs}" + + +def test_make_surrogate_data_generator(): + """Test `return_generator` parameter works in `make_surrogate_data`.""" + # Generate random data for packaging into EpochsSpectrum + n_epochs = 5 + n_chans = 6 + n_freqs = 50 + sfreq = n_freqs * 2 + rng = np.random.default_rng(44) + data = rng.random((n_epochs, n_chans, n_freqs)).astype(np.complex128) + data += data * 1j # complex dtypes not supported for simulation, so make complex + info = create_info(ch_names=n_chans, sfreq=sfreq, ch_types="eeg") + spectrum = EpochsSpectrumArray(data=data, info=info, freqs=np.arange(n_freqs)) + + # Test generator (not) returned when requested + surrogate_data = make_surrogate_data(data=spectrum, return_generator=True) + assert isinstance(surrogate_data, Generator), type(surrogate_data) + surrogate_data = make_surrogate_data(data=spectrum, return_generator=False) + assert isinstance(surrogate_data, list), type(surrogate_data) + + +def test_make_surrogate_data_error_catch(): + """Test error catching for `make_surrogate_data`.""" + # Generate random data for packaging into EpochsSpectrum + n_epochs = 5 + n_chans = 6 + n_freqs = 50 + sfreq = n_freqs * 2 + rng = np.random.default_rng(44) + data = rng.random((n_epochs, n_chans, n_freqs)).astype(np.complex128) + data += data * 1j # complex dtypes not supported for simulation, so make complex + info = create_info(ch_names=n_chans, sfreq=sfreq, ch_types="eeg") + spectrum = EpochsSpectrumArray(data=data, info=info, freqs=np.arange(n_freqs)) + + # check bad data + with pytest.raises(TypeError, match=r"data must be an instance of.*EpochsSpectrum"): + make_surrogate_data(data=data) + with pytest.raises(TypeError, match="values in `data` must be complex-valued"): + bad_dtype_data = EpochsSpectrumArray( + data=np.abs(data), info=info, freqs=np.arange(n_freqs) + ) + make_surrogate_data(data=bad_dtype_data) + with pytest.raises(ValueError, match="data must contain more than one epoch"): + bad_nepochs_data = EpochsSpectrumArray( + data=data[[0]], info=info, freqs=np.arange(n_freqs) + ) + make_surrogate_data(data=bad_nepochs_data) + with pytest.raises(ValueError, match="data must contain more than one channel"): + bad_nchans_data = EpochsSpectrumArray( + data=data[:, [0]], + info=create_info(ch_names=1, sfreq=sfreq, ch_types="eeg"), + freqs=np.arange(n_freqs), + ) + make_surrogate_data(data=bad_nchans_data) + + # check bad n_shuffles + with pytest.raises(TypeError, match="n_shuffles must be an instance of int"): + make_surrogate_data(data=spectrum, n_shuffles="all") + with pytest.raises(ValueError, match="number of shuffles must be >= 1"): + make_surrogate_data(data=spectrum, n_shuffles=0) + with pytest.raises(ValueError, match="number of shuffles must be >= 1"): + make_surrogate_data(data=spectrum, n_shuffles=-1) + + # check bad return_generator + with pytest.raises(TypeError, match="return_generator must be an instance of bool"): + make_surrogate_data(data=spectrum, return_generator="yes") diff --git a/mne_connectivity/tests/test_datasets.py 
b/mne_connectivity/tests/test_datasets.py deleted file mode 100644 index 4ae7d5ac2..000000000 --- a/mne_connectivity/tests/test_datasets.py +++ /dev/null @@ -1,169 +0,0 @@ -import numpy as np -import pytest - -from mne_connectivity import ( - make_signals_in_freq_bands, - seed_target_indices, - spectral_connectivity_epochs, -) - - -@pytest.mark.parametrize("n_seeds", [1, 3]) -@pytest.mark.parametrize("n_targets", [1, 3]) -@pytest.mark.parametrize("snr", [0.7, 0.4]) -@pytest.mark.parametrize("connection_delay", [0, 3, -3]) -@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) -def test_make_signals_in_freq_bands(n_seeds, n_targets, snr, connection_delay, mode): - """Test `make_signals_in_freq_bands` simulates connectivity properly.""" - # Case with no spurious correlations (avoids tests randomly failing) - rng_seed = 0 - - # Simulate data - freq_band = (5, 10) # fmin, fmax (Hz) - sfreq = 100 # Hz - trans_bandwidth = 1 # Hz - data = make_signals_in_freq_bands( - n_seeds=n_seeds, - n_targets=n_targets, - freq_band=freq_band, - n_epochs=30, - n_times=200, - sfreq=sfreq, - trans_bandwidth=trans_bandwidth, - snr=snr, - connection_delay=connection_delay, - rng_seed=rng_seed, - ) - - # Compute connectivity - methods = ["coh", "imcoh", "dpli"] - indices = seed_target_indices( - seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds - ) - fmin = 3 - fmax = sfreq // 2 - if mode == "cwt_morlet": - cwt_params = {"cwt_freqs": np.arange(fmin, fmax), "cwt_n_cycles": 3.5} - else: - cwt_params = dict() - con = spectral_connectivity_epochs( - data, - method=methods, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - **cwt_params, - ) - freqs = np.array(con[0].freqs) - - # Define expected connectivity values - thresh_good = dict() - thresh_bad = dict() - # Coh - thresh_good["coh"] = (0.2, 0.9) - thresh_bad["coh"] = (0.0, 0.2) - # ImCoh - if connection_delay == 0: - thresh_good["imcoh"] = (0.0, 0.17) - thresh_bad["imcoh"] = (0.0, 0.17) - else: - thresh_good["imcoh"] = (0.17, 0.8) - thresh_bad["imcoh"] = (0.0, 0.17) - # DPLI - if connection_delay == 0: - thresh_good["dpli"] = (0.3, 0.6) - thresh_bad["dpli"] = (0.3, 0.6) - elif connection_delay > 0: - thresh_good["dpli"] = (0.5, 1) - thresh_bad["dpli"] = (0.3, 0.6) - else: - thresh_good["dpli"] = (0, 0.5) - thresh_bad["dpli"] = (0.3, 0.6) - - # Check connectivity values are acceptable - freqs_good = np.argwhere( - (freqs >= freq_band[0]) & (freqs <= freq_band[1]) - ).flatten() - freqs_bad = np.argwhere( - (freqs < freq_band[0] - trans_bandwidth * 2) - | (freqs > freq_band[1] + trans_bandwidth * 2) - ).flatten() - for method_name, method_con in zip(methods, con): - con_values = method_con.get_data() - if method_name == "imcoh": - con_values = np.abs(con_values) - # freq. band of interest - con_values_good = np.mean(con_values[:, freqs_good]) - assert ( - con_values_good >= thresh_good[method_name][0] - and con_values_good <= thresh_good[method_name][1] - ) - - # other freqs. - con_values_bad = np.mean(con_values[:, freqs_bad]) - assert ( - con_values_bad >= thresh_bad[method_name][0] - and con_values_bad <= thresh_bad[method_name][1] - ) - - -def test_make_signals_error_catch(): - """Test error catching for `make_signals_in_freq_bands`.""" - freq_band = (5, 10) - - # check bad n_seeds/targets caught - with pytest.raises( - ValueError, match="Number of seeds and targets must each be at least 1." 
- ): - make_signals_in_freq_bands(n_seeds=0, n_targets=1, freq_band=freq_band) - with pytest.raises( - ValueError, match="Number of seeds and targets must each be at least 1." - ): - make_signals_in_freq_bands(n_seeds=1, n_targets=0, freq_band=freq_band) - - # check bad freq_band caught - with pytest.raises(TypeError, match="Frequency band must be a tuple."): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=1) - with pytest.raises(ValueError, match="Frequency band must contain two numbers."): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=(1, 2, 3)) - - # check bad n_times - with pytest.raises(ValueError, match="Number of timepoints must be at least 1."): - make_signals_in_freq_bands( - n_seeds=1, n_targets=1, freq_band=freq_band, n_times=0 - ) - - # check bad n_epochs - with pytest.raises(ValueError, match="Number of epochs must be at least 1."): - make_signals_in_freq_bands( - n_seeds=1, n_targets=1, freq_band=freq_band, n_epochs=0 - ) - - # check bad sfreq - with pytest.raises(ValueError, match="Sampling frequency must be > 0."): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, sfreq=0) - - # check bad snr - with pytest.raises( - ValueError, match="Signal-to-noise ratio must be between 0 and 1." - ): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=-1) - with pytest.raises( - ValueError, match="Signal-to-noise ratio must be between 0 and 1." - ): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=2) - - # check bad connection_delay - with pytest.raises( - ValueError, - match="Connection delay must be less than the total number of timepoints.", - ): - make_signals_in_freq_bands( - n_seeds=1, - n_targets=1, - freq_band=freq_band, - n_epochs=1, - n_times=1, - connection_delay=1, - )