From 6debece5f02f6d3f393c9191f4d98f36159c3c20 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 1 Aug 2024 15:59:49 +0200 Subject: [PATCH 01/18] Add EpochsSpectrum support --- .circleci/config.yml | 2 +- .github/workflows/linux_conda.yml | 2 +- .github/workflows/unit_tests.yml | 2 +- examples/sensor_connectivity.py | 31 +- mne_connectivity/spectral/epochs.py | 571 +++++++++++------- .../spectral/tests/test_spectral.py | 95 ++- mne_connectivity/spectral/time.py | 2 + 7 files changed, 471 insertions(+), 234 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 62cffbb18..99fcfa338 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -77,7 +77,7 @@ jobs: - run: name: Get Python running and install dependencies command: | - pip install git+https://github.com/mne-tools/mne-python@main + pip install git+https://github.com/tsbinns/mne-python/tree/complex_spectrum curl https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/circleci_dependencies.sh -o circleci_dependencies.sh chmod +x circleci_dependencies.sh ./circleci_dependencies.sh diff --git a/.github/workflows/linux_conda.yml b/.github/workflows/linux_conda.yml index d7de45605..7b7104f76 100644 --- a/.github/workflows/linux_conda.yml +++ b/.github/workflows/linux_conda.yml @@ -41,7 +41,7 @@ jobs: source ./get_minimal_commands.sh pip install .[test] name: 'Install dependencies' - - run: pip install git+https://github.com/mne-tools/mne-python@main + - run: pip install git+https://github.com/tsbinns/mne-python/tree/complex_spectrum - run: pip install -e . - run: | which mne diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 803a28991..34685c219 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -60,7 +60,7 @@ jobs: run: pip install --upgrade mne - name: Install MNE (main) if: matrix.mne-version == 'mne-main' - run: pip install git+https://github.com/mne-tools/mne-python@main + run: pip install git+https://github.com/tsbinns/mne-python/tree/complex_spectrum - run: python -c "import mne; print(mne.datasets.testing.data_path(verbose=True))" if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' - name: Display versions and environment information diff --git a/examples/sensor_connectivity.py b/examples/sensor_connectivity.py index a7c9f86da..15048a987 100644 --- a/examples/sensor_connectivity.py +++ b/examples/sensor_connectivity.py @@ -12,6 +12,8 @@ # # License: BSD (3-clause) +# %% + import os.path as op import mne @@ -22,7 +24,8 @@ print(__doc__) -############################################################################### +# %% + # Set parameters data_path = sample.data_path() raw_fname = op.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif") @@ -41,7 +44,7 @@ ) # Create epochs for the visual condition -event_id, tmin, tmax = 3, -0.2, 1.5 # need a long enough epoch for 5 cycles +event_id, tmin, tmax = 3, -0.2, 1.5 # want a long enough epoch for 5 cycles epochs = mne.Epochs( raw, events, @@ -52,24 +55,18 @@ baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6), ) +epochs.load_data().pick("grad") # just keep MEG and no EOG now -# Compute connectivity for band containing the evoked response. -# We exclude the baseline period: -fmin, fmax = 4.0, 9.0 -sfreq = raw.info["sfreq"] # the sampling frequency +# Compute Fourier coefficients for the epochs (returns an EpochsSpectrum object) +# (storing Fourier coefficients in EpochsSpectrum objects requires MNE >= 1.8) tmin = 0.0 # exclude the baseline period -epochs.load_data().pick_types(meg="grad") # just keep MEG and no EOG now +spectrum = epochs.compute_psd(method="multitaper", tmin=tmin, output="complex") + +# Compute connectivity for the frequency band containing the evoked response +# (passing EpochsSpectrum objects as data requires MNE-Connectivity >= 0.8) +fmin, fmax = 4.0, 9.0 con = spectral_connectivity_epochs( - epochs, - method="pli", - mode="multitaper", - sfreq=sfreq, - fmin=fmin, - fmax=fmax, - faverage=True, - tmin=tmin, - mt_adaptive=False, - n_jobs=1, + data=spectrum, method="pli", fmin=fmin, fmax=fmax, faverage=True, n_jobs=1 ) # Now, visualize the connectivity in 3D: diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 26436d9b2..7e951a76b 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -19,6 +19,11 @@ _psd_from_mt, _psd_from_mt_adaptive, ) +from mne.time_frequency.spectrum import ( + BaseSpectrum, + EpochsSpectrum, + EpochsSpectrumArray, +) from mne.time_frequency.tfr import cwt, morlet from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn @@ -35,6 +40,69 @@ ) +def _check_times(data, sfreq, times, tmin, tmax): + # get the data size and time scale + n_signals, _, times_in, warn_times = _get_and_verify_data_sizes( + data=data, sfreq=sfreq, times=times + ) + n_times_in = len(times_in) # XXX: Why not use times returned from above func? + + if tmin is not None and tmin < times_in[0]: + warn( + f"start time tmin={tmin:.2f} s outside of the time scope of the data " + f"[{times_in[0]:.2f} s, {times_in[-1]:.2f} s]" + ) + if tmax is not None and tmax > times_in[-1]: + warn( + f"stop time tmax={tmax:.2f} s outside of the time scope of the data " + f"[{times_in[0]:.2f} s, {times_in[-1]:.2f} s]" + ) + + mask = _time_mask(times_in, tmin, tmax, sfreq=sfreq) + tmin_idx, tmax_idx = np.where(mask)[0][[0, -1]] + tmax_idx += 1 + tmin_true = times_in[tmin_idx] + tmax_true = times_in[tmax_idx - 1] # time of last point used + + times = times_in[tmin_idx:tmax_idx] + n_times = len(times) + + logger.info( + f" using t={tmin_true:.3f}s..{tmax_true:.3f}s for estimation ({n_times} " + "points)" + ) + + return ( + n_signals, + times, + n_times, + times_in, + n_times_in, + tmin_idx, + tmax_idx, + warn_times, + ) + + +def _check_freqs(sfreq, fmin, n_times): + # check that fmin corresponds to at least 5 cycles + dur = float(n_times) / sfreq + five_cycle_freq = 5.0 / dur + if len(fmin) == 1 and fmin[0] == -np.inf: + # we use the 5 cycle freq. as default + fmin = np.array([five_cycle_freq]) + else: + if np.any(fmin < five_cycle_freq): + warn( + f"fmin={np.min(fmin):.3f} Hz corresponds to {dur * np.min(fmin):.3f} < " + f"5 cycles based on the epoch length {dur:.3f} sec, need at least " + f"{5.0 / np.min(fmin):.3f} sec epochs or fmin={five_cycle_freq:.3f}. " + "Spectrum estimate will be unreliable." + ) + + return fmin + + def _compute_freqs(n_times, sfreq, cwt_freqs, mode): from scipy.fft import rfftfreq @@ -80,6 +148,7 @@ def _prepare_connectivity( fmin, fmax, sfreq, + freqs, indices, method, mode, @@ -87,105 +156,39 @@ def _prepare_connectivity( n_bands, cwt_freqs, faverage, + spectrum_computed, ): """Check and precompute dimensions of results data.""" first_epoch = epoch_block[0] - # get the data size and time scale - n_signals, n_times_in, times_in, warn_times = _get_and_verify_data_sizes( - first_epoch, sfreq, times=times_in - ) - - n_times_in = len(times_in) - - if tmin is not None and tmin < times_in[0]: - warn( - f"start time tmin={tmin:.2f} s outside of the time scope of the data " - f"[{times_in[0]:.2f} s, {times_in[-1]:.2f} s]" - ) - if tmax is not None and tmax > times_in[-1]: - warn( - f"stop time tmax={tmax:.2f} s outside of the time scope of the data " - f"[{times_in[0]:.2f} s, {times_in[-1]:.2f} s]" - ) - - mask = _time_mask(times_in, tmin, tmax, sfreq=sfreq) - tmin_idx, tmax_idx = np.where(mask)[0][[0, -1]] - tmax_idx += 1 - tmin_true = times_in[tmin_idx] - tmax_true = times_in[tmax_idx - 1] # time of last point used - - times = times_in[tmin_idx:tmax_idx] - n_times = len(times) - - if any(this_method in _multivariate_methods for this_method in method): - multivariate_con = True - else: - multivariate_con = False - - if indices is None: - if multivariate_con: - if any(this_method in _gc_methods for this_method in method): - raise ValueError( - "indices must be specified when computing Granger causality, as " - "all-to-all connectivity is not supported" - ) - else: - logger.info("using all indices for multivariate connectivity") - # indices expected to be a masked array, even if not ragged - indices_use = ( - np.arange(n_signals, dtype=int)[np.newaxis, :], - np.arange(n_signals, dtype=int)[np.newaxis, :], - ) - indices_use = np.ma.masked_array(indices_use, mask=False, fill_value=-1) - else: - logger.info("only using indices for lower-triangular matrix") - # only compute r for lower-triangular region - indices_use = np.tril_indices(n_signals, -1) - else: - if multivariate_con: - # pad ragged indices and mask the invalid entries - indices_use = _check_multivariate_indices(indices, n_signals) - if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices_use[0], indices_use[1]): - intersection = np.intersect1d( - seed.compressed(), target.compressed() - ) - if intersection.size > 0: - raise ValueError( - "seed and target indices must not intersect when computing " - "Granger causality" - ) - else: - indices_use = check_indices(indices) - - # number of connectivities to compute - n_cons = len(indices_use[0]) - - logger.info(f" computing connectivity for {n_cons} connections") - logger.info( - f" using t={tmin_true:.3f}s..{tmax_true:.3f}s for estimation ({n_times} " - "points)" - ) - - # check that fmin corresponds to at least 5 cycles - dur = float(n_times) / sfreq - five_cycle_freq = 5.0 / dur - if len(fmin) == 1 and fmin[0] == -np.inf: - # we use the 5 cycle freq. as default - fmin = np.array([five_cycle_freq]) + # Sort times and freqs + if spectrum_computed: + n_signals = first_epoch[0].shape[0] + times = None + n_times = None + times_in = None + n_times_in = None + tmin_idx = None + tmax_idx = None + warn_times = False else: - if np.any(fmin < five_cycle_freq): - warn( - f"fmin={np.min(fmin):.3f} Hz corresponds to {dur * np.min(fmin):.3f} < " - f"5 cycles based on the epoch length {dur:.3f} sec, need at least " - f"{5.0 / np.min(fmin):.3f} sec epochs or fmin={five_cycle_freq:.3f}. " - "Spectrum estimate will be unreliable." - ) - - # compute frequencies to analyze based on number of samples, - # sampling rate, specified wavelet frequencies and mode - freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode) + ( + n_signals, + times, + n_times, + times_in, + n_times_in, + tmin_idx, + tmax_idx, + warn_times, + ) = _check_times( + data=first_epoch, sfreq=sfreq, times=times_in, tmin=tmin, tmax=tmax + ) + # check that fmin corresponds to at least 5 cycles + fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times) + # compute frequencies to analyze based on number of samples, sampling rate, + # specified wavelet frequencies, and mode + freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode) # compute the mask based on specified min/max and decimation factor freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip) @@ -223,6 +226,51 @@ def _prepare_connectivity( if faverage: logger.info(" connectivity scores will be averaged for each band") + # Sort indices + multivariate_con = any( + this_method in _multivariate_methods for this_method in method + ) + + if indices is None: + if multivariate_con: + if any(this_method in _gc_methods for this_method in method): + raise ValueError( + "indices must be specified when computing Granger causality, as " + "all-to-all connectivity is not supported" + ) + logger.info("using all indices for multivariate connectivity") + # indices expected to be a masked array, even if not ragged + indices_use = ( + np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :], + ) + indices_use = np.ma.masked_array(indices_use, mask=False, fill_value=-1) + else: + logger.info("only using indices for lower-triangular matrix") + # only compute r for lower-triangular region + indices_use = np.tril_indices(n_signals, -1) + else: + if multivariate_con: + # pad ragged indices and mask the invalid entries + indices_use = _check_multivariate_indices(indices, n_signals) + if any(this_method in _gc_methods for this_method in method): + for seed, target in zip(indices_use[0], indices_use[1]): + intersection = np.intersect1d( + seed.compressed(), target.compressed() + ) + if intersection.size > 0: + raise ValueError( + "seed and target indices must not intersect when computing " + "Granger causality" + ) + else: + indices_use = check_indices(indices) + + # number of connections to compute + n_cons = len(indices_use[0]) + + logger.info(f" computing connectivity for {n_cons} connections") + return ( n_cons, times, @@ -255,7 +303,7 @@ def _assemble_spectral_params( freq_mask, ): """Prepare time-frequency decomposition.""" - spectral_params = dict(eigvals=None, window_fun=None, wavelets=None) + spectral_params = dict(eigvals=None, window_fun=None, wavelets=None, weights=None) n_tapers = None n_times_spectrum = 0 if mode == "multitaper": @@ -313,6 +361,79 @@ def compute_con(self, con_idx, n_epochs): ######################################################################## +def _compute_spectra( + data, + sfreq, + mode, + sig_idx, + tmin_idx, + tmax_idx, + mt_adaptive, + eigvals, + wavelets, + window_fun, + freq_mask, + accumulate_psd, +): + x_t = list() + this_psd = list() + for this_data in data: + if mode in ("multitaper", "fourier"): + if isinstance(this_data, _BaseSourceEstimate): + _mt_spectra_partial = partial(_mt_spectra, dpss=window_fun, sfreq=sfreq) + this_x_t = this_data.transform_data( + _mt_spectra_partial, + idx=sig_idx, + tmin_idx=tmin_idx, + tmax_idx=tmax_idx, + ) + else: + this_x_t, _ = _mt_spectra( + this_data[sig_idx, tmin_idx:tmax_idx], window_fun, sfreq + ) + + if mt_adaptive: + # compute PSD and adaptive weights + _this_psd, weights = _psd_from_mt_adaptive( + this_x_t, eigvals, freq_mask, return_weights=True + ) + + # only keep freqs of interest + this_x_t = this_x_t[:, :, freq_mask] + else: + # do not use adaptive weights + this_x_t = this_x_t[:, :, freq_mask] + if mode == "multitaper": + weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis] + else: + # hack to so we can sum over axis=-2 + weights = np.array([1.0])[:, None, None] + + if accumulate_psd: + _this_psd = _psd_from_mt(this_x_t, weights) + else: # mode == 'cwt_morlet' + weights = None + if isinstance(this_data, _BaseSourceEstimate): + cwt_partial = partial(cwt, Ws=wavelets, use_fft=True, mode="same") + this_x_t = this_data.transform_data( + cwt_partial, idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx + ) + else: + this_x_t = cwt( + this_data[sig_idx, tmin_idx:tmax_idx], + wavelets, + use_fft=True, + mode="same", + ) + _this_psd = (this_x_t * this_x_t.conj()).real + + x_t.append(this_x_t) + if accumulate_psd: + this_psd.append(_this_psd) + + return x_t, this_psd, weights + + def _epoch_spectral_connectivity( data, sig_idx, @@ -323,6 +444,7 @@ def _epoch_spectral_connectivity( mode, window_fun, eigvals, + weights, wavelets, freq_mask, mt_adaptive, @@ -338,6 +460,7 @@ def _epoch_spectral_connectivity( n_times, gc_n_lags, n_components, + spectrum_computed, accumulate_inplace=True, ): """Estimate connectivity for one epoch (see spectral_connectivity).""" @@ -368,7 +491,7 @@ def _epoch_spectral_connectivity( ) ) else: - # if it's a coherence method + # if it's a coherency-based method con_methods.append( mtype( n_signals_use, @@ -387,60 +510,30 @@ def _epoch_spectral_connectivity( sig_idx = slice(None, None) # compute tapered spectra - x_t = list() - this_psd = list() - for this_data in data: - if mode in ("multitaper", "fourier"): - if isinstance(this_data, _BaseSourceEstimate): - _mt_spectra_partial = partial(_mt_spectra, dpss=window_fun, sfreq=sfreq) - this_x_t = this_data.transform_data( - _mt_spectra_partial, - idx=sig_idx, - tmin_idx=tmin_idx, - tmax_idx=tmax_idx, - ) - else: - this_x_t, _ = _mt_spectra( - this_data[sig_idx, tmin_idx:tmax_idx], window_fun, sfreq - ) - - if mt_adaptive: - # compute PSD and adaptive weights - _this_psd, weights = _psd_from_mt_adaptive( - this_x_t, eigvals, freq_mask, return_weights=True - ) - - # only keep freqs of interest - this_x_t = this_x_t[:, :, freq_mask] - else: - # do not use adaptive weights - this_x_t = this_x_t[:, :, freq_mask] - if mode == "multitaper": - weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis] - else: - # hack to so we can sum over axis=-2 - weights = np.array([1.0])[:, None, None] - - if accumulate_psd: - _this_psd = _psd_from_mt(this_x_t, weights) - else: # mode == 'cwt_morlet' - if isinstance(this_data, _BaseSourceEstimate): - cwt_partial = partial(cwt, Ws=wavelets, use_fft=True, mode="same") - this_x_t = this_data.transform_data( - cwt_partial, idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx - ) - else: - this_x_t = cwt( - this_data[sig_idx, tmin_idx:tmax_idx], - wavelets, - use_fft=True, - mode="same", - ) - _this_psd = (this_x_t * this_x_t.conj()).real - - x_t.append(this_x_t) + if spectrum_computed: # use existing spectral info + # XXX: Will need to distinguish time-resolved spectra here if support added + # Select signals & freqs of interest (flexible indexing for optional tapers dim) + x_t = np.array(data)[:, sig_idx, ..., freq_mask] + if weights is None: # also assumes no tapers dim + x_t = np.expand_dims(x_t, axis=2) # CSD construction expects a tapers dim + weights = np.array([1.0])[:, None, None] if accumulate_psd: - this_psd.append(_this_psd) + this_psd = _psd_from_mt(x_t, weights) + else: # compute spectral info from scratch + x_t, this_psd, weights = _compute_spectra( + data=data, + sfreq=sfreq, + mode=mode, + sig_idx=sig_idx, + tmin_idx=tmin_idx, + tmax_idx=tmax_idx, + mt_adaptive=mt_adaptive, + eigvals=eigvals, + wavelets=wavelets, + window_fun=window_fun, + freq_mask=freq_mask, + accumulate_psd=accumulate_psd, + ) x_t = np.concatenate(x_t, axis=0) if accumulate_psd: @@ -634,14 +727,28 @@ def spectral_connectivity_epochs( Parameters ---------- - data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs - The data from which to compute connectivity. Note that it is also - possible to combine multiple signals by providing a list of tuples, - e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], - corresponds to 3 epochs, and arr_* could be an array with the same - number of time points as stc_*. The array-like object can also - be a list/generator of array, shape =(n_signals, n_times), - or a list/generator of SourceEstimate or VolSourceEstimate objects. + data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs | ~mne.time_frequency.EpochsSpectrum + The data from which to compute connectivity. Can be epoched timeseries data as + an `array` or `~mne.Epochs` object, or Fourier coefficients for each epoch as an + `~mne.time_frequency.EpochsSpectrum` object. If timeseries data, the spectral + information will be computed according to the spectral estimation mode (see the + ``mode`` parameter). If an `~mne.time_frequency.EpochsSpectrum` object, this + spectral information will be used and the ``mode`` parameter will be ignored. + + Note that it is also possible to combine multiple timeseries signals by + providing a list of tuples, e.g.: :: + + data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)] + + which corresponds to 3 epochs where ``arr_*`` is an array with the same number + of time points as ``stc_*``. Data can also be a `list`/:term:`generator` of + arrays, ``shape (n_signals, n_times)``, or a `list`/:term:`generator` of + `~mne.SourceEstimate` or `~mne.VolSourceEstimate` objects. + + .. versionchanged:: 0.8 + Fourier coefficients stored in an `~mne.time_frequency.EpochsSpectrum` or + `~mne.time_frequency.EpochsSpectrumArray` object can also be passed in as + data. Storing Fourier coefficients requires ``mne >= 1.8``. %(names)s method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cohy', @@ -677,11 +784,11 @@ def spectral_connectivity_epochs( connections between all channels are computed, unless a Granger causality method is called, in which case an error is raised. sfreq : float - The sampling frequency. Required if data is not - :class:`Epochs `. + The sampling frequency. Required if data is an :term:`array-like`. mode : str Spectrum estimation mode can be either: 'multitaper', 'fourier', or - 'cwt_morlet'. + 'cwt_morlet'. Ignored if ``data`` is an `~mne.time_frequency.EpochsSpectrum` + object. fmin : float | tuple of float The lower frequency of interest. Multiple bands are defined using a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. @@ -696,29 +803,33 @@ def spectral_connectivity_epochs( the output freqs will be a list with arrays of the frequencies that were averaged. tmin : float | None - Time to start connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. + Time to start connectivity estimation. Note: when ``data`` is an `array`, the + first sample is assumed to be at time 0. For `~mne.Epochs`, the time information + contained in the object is used to compute the time indices. Ignored if ``data`` + is an `~mne.time_frequency.EpochsSpectrum` object. tmax : float | None - Time to end connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. + Time to end connectivity estimation. Note: when ``data`` is an `array`, the + first sample is assumed to be at time 0. For `~mne.Epochs`, the time information + contained in the object is used to compute the time indices. Ignored if ``data`` + is an `~mne.time_frequency.EpochsSpectrum` object. mt_bandwidth : float | None - The bandwidth of the multitaper windowing function in Hz. - Only used in 'multitaper' mode. + The bandwidth of the multitaper windowing function in Hz. Only used in + 'multitaper' mode. Ignored if ``data`` is an + `~mne.time_frequency.EpochsSpectrum` object. mt_adaptive : bool - Use adaptive weights to combine the tapered spectra into PSD. - Only used in 'multitaper' mode. + Use adaptive weights to combine the tapered spectra into PSD. Only used in + 'multitaper' mode. Ignored if ``data`` is an + `~mne.time_frequency.EpochsSpectrum` object. mt_low_bias : bool - Only use tapers with more than 90 percent spectral concentration - within bandwidth. Only used in 'multitaper' mode. + Only use tapers with more than 90 percent spectral concentration within + bandwidth. Only used in 'multitaper' mode. Ignored if ``data`` is an + `~mne.time_frequency.EpochsSpectrum` object. cwt_freqs : array - Array of frequencies of interest. Only used in 'cwt_morlet' mode. + Array of frequencies of interest. Only used in 'cwt_morlet' mode. Ignored if + ``data`` is an `~mne.time_frequency.EpochsSpectrum` object. cwt_n_cycles : float | array of float - Number of cycles. Fixed number or one per frequency. Only used in - 'cwt_morlet' mode. + Number of cycles. Fixed number or one per frequency. Only used in 'cwt_morlet' + mode. Ignored if ``data`` is an `~mne.time_frequency.EpochsSpectrum` object. gc_n_lags : int Number of lags to use for the vector autoregressive model when computing Granger causality. Higher values increase computational cost, @@ -737,7 +848,7 @@ def spectral_connectivity_epochs( minimum rank of the seeds and targets is extracted (see the ``rank`` parameter). Only used if ``method`` contains any of ``['cacoh', 'mic']``. - .. versionadded:: 0.8.0 + .. versionadded:: 0.8 block_size : int How many connections to compute at once (higher numbers are faster but require more memory). @@ -941,7 +1052,7 @@ def spectral_connectivity_epochs( References ---------- .. footbibliography:: - """ + """ # noqa: E501 if n_jobs != 1: parallel, my_epoch_spectral_connectivity, n_jobs = parallel_func( _epoch_spectral_connectivity, n_jobs, verbose=verbose @@ -983,11 +1094,15 @@ def spectral_connectivity_epochs( # handle connectivity estimators (con_method_types, n_methods, accumulate_psd) = _check_estimators(method) + times_in = None events = None event_id = None - if isinstance(data, BaseEpochs): + freqs = None + weights = None + metadata = None + spectrum_computed = False + if isinstance(data, (BaseEpochs, EpochsSpectrum, EpochsSpectrumArray)): names = data.ch_names - times_in = data.times # input times for Epochs input type sfreq = data.info["sfreq"] events = data.events @@ -1006,11 +1121,33 @@ def spectral_connectivity_epochs( if hasattr(data, "annotations") and not annots_in_metadata: data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata - else: - times_in = None - metadata = None - if sfreq is None: - raise ValueError("Sampling frequency (sfreq) is required with array input.") + + if isinstance(data, (EpochsSpectrum, EpochsSpectrumArray)): + # XXX: Will need to be updated if new Spectrum methods are added + if not np.iscomplexobj(data.get_data()): + raise TypeError( + "if `data` is an EpochsSpectrum object, it must contain " + "complex-valued Fourier coefficients, such as that returned from " + "Epochs.compute_psd(output='complex')" + ) + if "segment" in data._dims: + raise ValueError( + "`data` cannot contain Fourier coefficients for individual segments" + ) + if isinstance(data, EpochsSpectrum): # mode can be read mode from Spectrum + mode = data.method + mode = "fourier" if mode == "welch" else mode + else: # spectral method is "unknown", so take mode from data dimensions + # Currently, actual mode doesn't matter as long as we handle tapers and + # their weights in the same way as for multitaper spectra + mode = "multitaper" if "taper" in data._dims else "fourier" + spectrum_computed = True + freqs = data.freqs + weights = data.weights + else: + times_in = data.times # input times for Epochs input type + elif sfreq is None: + raise ValueError("Sampling frequency (sfreq) is required with array input.") # loop over data; it could be a generator that returns # (n_signals x n_times) arrays or SourceEstimates @@ -1044,6 +1181,7 @@ def spectral_connectivity_epochs( fmin=fmin, fmax=fmax, sfreq=sfreq, + freqs=freqs, indices=indices, method=method, mode=mode, @@ -1051,6 +1189,7 @@ def spectral_connectivity_epochs( n_bands=n_bands, cwt_freqs=cwt_freqs, faverage=faverage, + spectrum_computed=spectrum_computed, ) # check rank input and compute data ranks if necessary @@ -1073,23 +1212,27 @@ def spectral_connectivity_epochs( indices = (indices_use[0].copy(), indices_use[1].copy()) # get the window function, wavelets, etc for different modes - ( - spectral_params, - mt_adaptive, - n_times_spectrum, - n_tapers, - ) = _assemble_spectral_params( - mode=mode, - n_times=n_times, - mt_adaptive=mt_adaptive, - mt_bandwidth=mt_bandwidth, - sfreq=sfreq, - mt_low_bias=mt_low_bias, - cwt_n_cycles=cwt_n_cycles, - cwt_freqs=cwt_freqs, - freqs=freqs, - freq_mask=freq_mask, - ) + if not spectrum_computed: + spectral_params, mt_adaptive, n_times_spectrum, n_tapers = ( + _assemble_spectral_params( + mode=mode, + n_times=n_times, + mt_adaptive=mt_adaptive, + mt_bandwidth=mt_bandwidth, + sfreq=sfreq, + mt_low_bias=mt_low_bias, + cwt_n_cycles=cwt_n_cycles, + cwt_freqs=cwt_freqs, + freqs=freqs, + freq_mask=freq_mask, + ) + ) + else: + spectral_params = dict( + eigvals=None, window_fun=None, wavelets=None, weights=weights + ) + n_times_spectrum = 0 + n_tapers = None if weights is None else weights.size # unique signals for which we actually need to compute PSD etc. if multivariate_con: @@ -1142,15 +1285,16 @@ def spectral_connectivity_epochs( logger.info(f" the following metrics will be computed: {metrics_str}") # check dimensions and time scale - for this_epoch in epoch_block: - _, _, _, warn_times = _get_and_verify_data_sizes( - this_epoch, - sfreq, - n_signals, - n_times_in, - times_in, - warn_times=warn_times, - ) + if not spectrum_computed: # XXX: Can we assume upstream checks sufficient? + for this_epoch in epoch_block: + _, _, _, warn_times = _get_and_verify_data_sizes( + this_epoch, + sfreq, + n_signals, + n_times_in, + times_in, + warn_times=warn_times, + ) call_params = dict( sig_idx=sig_idx, @@ -1173,6 +1317,7 @@ def spectral_connectivity_epochs( n_times=n_times, gc_n_lags=gc_n_lags, n_components=n_components, + spectrum_computed=spectrum_computed, accumulate_inplace=True if n_jobs == 1 else False, ) call_params.update(**spectral_params) @@ -1320,7 +1465,7 @@ def spectral_connectivity_epochs( freqs=freqs, method=_method, n_nodes=n_nodes, - spec_method=mode, + spec_method=mode if not isinstance(data, BaseSpectrum) else data.method, indices=indices, n_epochs_used=n_epochs, freqs_used=freqs_used, diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 6f913b4f0..e389d41c3 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -2,17 +2,20 @@ import os import platform +import mne import numpy as np import pandas as pd import pytest from mne import EpochsArray, SourceEstimate, create_info from mne.filter import filter_data from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_less +from packaging.version import Version from mne_connectivity import ( SpectralConnectivity, make_signals_in_freq_bands, read_connectivity, + seed_target_indices, spectral_connectivity_epochs, spectral_connectivity_time, ) @@ -467,6 +470,96 @@ def test_spectral_connectivity(method, mode): assert out_lens[0] == 10 +# Fourier coeffs in Spectrum objects added in MNE v1.8.0 +@pytest.mark.skipif( + Version(mne.__version__) <= Version("1.7.1"), reason="Requires MNE v1.8.0 or higher" +) +@pytest.mark.parametrize("mode", ["multitaper", "fourier"]) +def test_spectral_connectivity_epochs_spectrum_input(mode): + """Test spec_conn_epochs works with EpochsSpectrum data as input.""" + # Simulation parameters & data generation + sfreq = 100.0 # Hz + n_seeds = 2 + n_targets = 2 + fband = (15, 20) # Hz + n_epochs = 30 + n_times = 200 # samples + trans_bandwidth = 1.0 # Hz + delay = 10 # samples + + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=fband, + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + trans_bandwidth=trans_bandwidth, + snr=0.7, + connection_delay=delay, + rng_seed=44, + ) + + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + + # Compute Fourier coefficients + coeffs = data.compute_psd( + method="welch" if mode == "fourier" else mode, output="complex" + ) + + # Compute connectivity (just coherence) + con = spectral_connectivity_epochs(data=coeffs, method="coh", indices=indices) + + # Check connectivity from Epochs and Spectrum are equivalent; + # Works for multitaper, but Welch of Spectrum and Fourier of spec_conn are slightly + # off (max. abs. diff. ~0.006) even when what should be identical settings are used + if mode == "multitaper": + con_from_epochs = spectral_connectivity_epochs( + data=data, method="coh", indices=indices + ) + # spec_conn_epochs excludes freqs without at least 5 cycles, but not Spectrum + fstart = con.freqs.index(con_from_epochs.freqs[0]) + assert_allclose(con.get_data()[:, fstart:], con_from_epochs.get_data()) + + # Check connectivity values are as expected + freqs = np.array(con.freqs) + freqs_con = (freqs >= fband[0]) & (freqs <= fband[1]) + freqs_noise = (freqs < fband[0] - trans_bandwidth * 2) | ( + freqs > fband[1] + trans_bandwidth * 2 + ) + + assert_array_less(0.6, con.get_data()[:, freqs_con]) + assert_array_less(con.get_data()[:, freqs_noise], 0.2) + + +# TODO: Add general test for error catching for spec_conn_epochs +# Fourier coeffs in Spectrum objects added in MNE v1.8.0 +@pytest.mark.skipif( + Version(mne.__version__) <= Version("1.7.1"), reason="Requires MNE v1.8.0 or higher" +) +def test_spectral_connectivity_epochs_spectrum_input_error_catch(): + """Test spec_conn_epochs catches error with EpochsSpectrum data as input.""" + # Generate data + rng = np.random.default_rng(44) + n_epochs, n_chans, n_times = (5, 2, 50) + sfreq = 50 + data = rng.random((n_epochs, n_chans, n_times)) + info = create_info(ch_names=n_chans, sfreq=sfreq, ch_types="eeg") + data = EpochsArray(data=data, info=info) + + # Test not Fourier coefficients caught + with pytest.raises(TypeError, match="must contain complex-valued Fourier coeff"): + spectrum = data.compute_psd(output="power") + spectral_connectivity_epochs(data=spectrum) + + # Test unaggregated segments caught + with pytest.raises(ValueError, match=r"cannot contain Fourier coeff.*segments"): + spectrum = data.compute_psd(method="welch", average=False, output="complex") + spectral_connectivity_epochs(data=spectrum) + + _gc_marks = [] if platform.system() == "Darwin" and platform.processor() == "arm": _gc_marks.extend( @@ -527,7 +620,7 @@ def test_spectral_connectivity_epochs_multivariate(method, n_components): freqs = np.array(con.freqs) freqs_con = (freqs >= fstart) & (freqs <= fend) freqs_noise = (freqs < fstart - trans_bandwidth * 2) | ( - freqs > fend + -trans_bandwidth * 2 + freqs > fend + trans_bandwidth * 2 ) # Check connectivity scores are in expected range diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 55f51c5be..6e43a40f2 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -165,6 +165,8 @@ def spectral_connectivity_time( ``n_components`` must be <= 3. If `None`, the number of components equal to the minimum rank of the seeds and targets is extracted (see the ``rank`` parameter). Only used if ``method`` contains any of ``['cacoh', 'mic']``. + + .. versionadded:: 0.8 decim : int To reduce memory usage, decimation factor after time-frequency decomposition. Returns ``tfr[…, ::decim]``. From b517e79a15d88d810a73dbff4aac080ef781ee87 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 1 Aug 2024 16:18:10 +0200 Subject: [PATCH 02/18] Fix CircleCI --- .circleci/config.yml | 2 +- .github/workflows/linux_conda.yml | 2 +- .github/workflows/unit_tests.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 99fcfa338..f78ff8422 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -77,7 +77,7 @@ jobs: - run: name: Get Python running and install dependencies command: | - pip install git+https://github.com/tsbinns/mne-python/tree/complex_spectrum + pip install git+https://github.com/tsbinns/mne-python@complex_spectrum curl https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/circleci_dependencies.sh -o circleci_dependencies.sh chmod +x circleci_dependencies.sh ./circleci_dependencies.sh diff --git a/.github/workflows/linux_conda.yml b/.github/workflows/linux_conda.yml index 7b7104f76..9b985af66 100644 --- a/.github/workflows/linux_conda.yml +++ b/.github/workflows/linux_conda.yml @@ -41,7 +41,7 @@ jobs: source ./get_minimal_commands.sh pip install .[test] name: 'Install dependencies' - - run: pip install git+https://github.com/tsbinns/mne-python/tree/complex_spectrum + - run: pip install git+https://github.com/tsbinns/mne-python@complex_spectrum - run: pip install -e . - run: | which mne diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 34685c219..0d591f0ca 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -60,7 +60,7 @@ jobs: run: pip install --upgrade mne - name: Install MNE (main) if: matrix.mne-version == 'mne-main' - run: pip install git+https://github.com/tsbinns/mne-python/tree/complex_spectrum + run: pip install git+https://github.com/tsbinns/mne-python@complex_spectrum - run: python -c "import mne; print(mne.datasets.testing.data_path(verbose=True))" if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' - name: Display versions and environment information From dccecdc51318a2bcd17e04e2657891d756022a8b Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 1 Aug 2024 16:18:27 +0200 Subject: [PATCH 03/18] Add reminder to change version --- mne_connectivity/spectral/tests/test_spectral.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index e389d41c3..927877de5 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -471,6 +471,7 @@ def test_spectral_connectivity(method, mode): # Fourier coeffs in Spectrum objects added in MNE v1.8.0 +# FIXME: Update to reference MNE 1.8 instead of 1.7.1 when 1.8 released @pytest.mark.skipif( Version(mne.__version__) <= Version("1.7.1"), reason="Requires MNE v1.8.0 or higher" ) @@ -536,6 +537,7 @@ def test_spectral_connectivity_epochs_spectrum_input(mode): # TODO: Add general test for error catching for spec_conn_epochs # Fourier coeffs in Spectrum objects added in MNE v1.8.0 +# FIXME: Update to reference MNE 1.8 instead of 1.7.1 when 1.8 released @pytest.mark.skipif( Version(mne.__version__) <= Version("1.7.1"), reason="Requires MNE v1.8.0 or higher" ) From 45bf496fc291564eb8fd7765d67182cceef72b9e Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Thu, 1 Aug 2024 16:46:22 +0200 Subject: [PATCH 04/18] Update Spectrum skips Co-authored-by: Eric Larson --- mne_connectivity/spectral/tests/test_spectral.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 927877de5..58743351c 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -536,10 +536,7 @@ def test_spectral_connectivity_epochs_spectrum_input(mode): # TODO: Add general test for error catching for spec_conn_epochs -# Fourier coeffs in Spectrum objects added in MNE v1.8.0 -# FIXME: Update to reference MNE 1.8 instead of 1.7.1 when 1.8 released -@pytest.mark.skipif( - Version(mne.__version__) <= Version("1.7.1"), reason="Requires MNE v1.8.0 or higher" +@pytest.mark.skipif(not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" ) def test_spectral_connectivity_epochs_spectrum_input_error_catch(): """Test spec_conn_epochs catches error with EpochsSpectrum data as input.""" From 0b8790b98a4d1306535e2a84cd38f025f676a834 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 1 Aug 2024 17:10:31 +0200 Subject: [PATCH 05/18] Fix broken tests and version checking --- .../spectral/tests/test_spectral.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 58743351c..9631b83bc 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -2,14 +2,13 @@ import os import platform -import mne import numpy as np import pandas as pd import pytest from mne import EpochsArray, SourceEstimate, create_info from mne.filter import filter_data +from mne.utils import check_version from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_less -from packaging.version import Version from mne_connectivity import ( SpectralConnectivity, @@ -471,9 +470,8 @@ def test_spectral_connectivity(method, mode): # Fourier coeffs in Spectrum objects added in MNE v1.8.0 -# FIXME: Update to reference MNE 1.8 instead of 1.7.1 when 1.8 released @pytest.mark.skipif( - Version(mne.__version__) <= Version("1.7.1"), reason="Requires MNE v1.8.0 or higher" + not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" ) @pytest.mark.parametrize("mode", ["multitaper", "fourier"]) def test_spectral_connectivity_epochs_spectrum_input(mode): @@ -496,7 +494,7 @@ def test_spectral_connectivity_epochs_spectrum_input(mode): n_times=n_times, sfreq=sfreq, trans_bandwidth=trans_bandwidth, - snr=0.7, + snr=0.5, connection_delay=delay, rng_seed=44, ) @@ -531,12 +529,20 @@ def test_spectral_connectivity_epochs_spectrum_input(mode): freqs > fband[1] + trans_bandwidth * 2 ) - assert_array_less(0.6, con.get_data()[:, freqs_con]) - assert_array_less(con.get_data()[:, freqs_noise], 0.2) + if mode == "multitaper": # lower baseline for multitaper + con_thresh = (0.1, 0.3) + else: # higher baseline for Welch/Fourier + con_thresh = (0.2, 0.4) + + # check freqs of simulated interaction show strong connectivity + assert_array_less(con_thresh[1], con.get_data()[:, freqs_con].mean()) + # check freqs of no simulated interaction (just noise) show weak connectivity + assert_array_less(con.get_data()[:, freqs_noise].mean(), con_thresh[0]) # TODO: Add general test for error catching for spec_conn_epochs -@pytest.mark.skipif(not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" +@pytest.mark.skipif( + not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" ) def test_spectral_connectivity_epochs_spectrum_input_error_catch(): """Test spec_conn_epochs catches error with EpochsSpectrum data as input.""" From dac41534a7aab49f8c552a1fecb436a057eb62ce Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 1 Aug 2024 19:51:17 +0200 Subject: [PATCH 06/18] Be explicit with intersphinx roles --- mne_connectivity/spectral/epochs.py | 56 +++++++++++++++-------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 7e951a76b..c252d359a 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -729,11 +729,12 @@ def spectral_connectivity_epochs( ---------- data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs | ~mne.time_frequency.EpochsSpectrum The data from which to compute connectivity. Can be epoched timeseries data as - an `array` or `~mne.Epochs` object, or Fourier coefficients for each epoch as an - `~mne.time_frequency.EpochsSpectrum` object. If timeseries data, the spectral - information will be computed according to the spectral estimation mode (see the - ``mode`` parameter). If an `~mne.time_frequency.EpochsSpectrum` object, this - spectral information will be used and the ``mode`` parameter will be ignored. + an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients + for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` object. If + timeseries data, the spectral information will be computed according to the + spectral estimation mode (see the ``mode`` parameter). If an + :class:`~mne.time_frequency.EpochsSpectrum` object, this spectral information + will be used and the ``mode`` parameter will be ignored. Note that it is also possible to combine multiple timeseries signals by providing a list of tuples, e.g.: :: @@ -741,14 +742,14 @@ def spectral_connectivity_epochs( data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)] which corresponds to 3 epochs where ``arr_*`` is an array with the same number - of time points as ``stc_*``. Data can also be a `list`/:term:`generator` of - arrays, ``shape (n_signals, n_times)``, or a `list`/:term:`generator` of - `~mne.SourceEstimate` or `~mne.VolSourceEstimate` objects. + of time points as ``stc_*``. Data can also be a :class:`list`/:term:`generator` + of arrays, ``shape (n_signals, n_times)``, or a :class:`list`/:term:`generator` + of :class:`~mne.SourceEstimate` or :class:`~mne.VolSourceEstimate` objects. .. versionchanged:: 0.8 - Fourier coefficients stored in an `~mne.time_frequency.EpochsSpectrum` or - `~mne.time_frequency.EpochsSpectrumArray` object can also be passed in as - data. Storing Fourier coefficients requires ``mne >= 1.8``. + Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsSpectrum` + or :class:`~mne.time_frequency.EpochsSpectrumArray` object can also be passed + in as data. Storing Fourier coefficients requires ``mne >= 1.8``. %(names)s method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cohy', @@ -787,8 +788,8 @@ def spectral_connectivity_epochs( The sampling frequency. Required if data is an :term:`array-like`. mode : str Spectrum estimation mode can be either: 'multitaper', 'fourier', or - 'cwt_morlet'. Ignored if ``data`` is an `~mne.time_frequency.EpochsSpectrum` - object. + 'cwt_morlet'. Ignored if ``data`` is an + :class:`~mne.time_frequency.EpochsSpectrum` object. fmin : float | tuple of float The lower frequency of interest. Multiple bands are defined using a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. @@ -803,33 +804,36 @@ def spectral_connectivity_epochs( the output freqs will be a list with arrays of the frequencies that were averaged. tmin : float | None - Time to start connectivity estimation. Note: when ``data`` is an `array`, the - first sample is assumed to be at time 0. For `~mne.Epochs`, the time information - contained in the object is used to compute the time indices. Ignored if ``data`` - is an `~mne.time_frequency.EpochsSpectrum` object. + Time to start connectivity estimation. Note: when ``data`` is an + :term:`array-like`, the first sample is assumed to be at time 0. For + :class:`~mne.Epochs`, the time information contained in the object is used to + compute the time indices. Ignored if ``data`` is an + :class:`~mne.time_frequency.EpochsSpectrum` object. tmax : float | None - Time to end connectivity estimation. Note: when ``data`` is an `array`, the - first sample is assumed to be at time 0. For `~mne.Epochs`, the time information - contained in the object is used to compute the time indices. Ignored if ``data`` - is an `~mne.time_frequency.EpochsSpectrum` object. + Time to end connectivity estimation. Note: when ``data`` is an + :term:`array-like`, the first sample is assumed to be at time 0. For + :class:`~mne.Epochs`, the time information contained in the object is used to + compute the time indices. Ignored if ``data`` is an + :class:`~mne.time_frequency.EpochsSpectrum` object. mt_bandwidth : float | None The bandwidth of the multitaper windowing function in Hz. Only used in 'multitaper' mode. Ignored if ``data`` is an - `~mne.time_frequency.EpochsSpectrum` object. + :class:`~mne.time_frequency.EpochsSpectrum` object. mt_adaptive : bool Use adaptive weights to combine the tapered spectra into PSD. Only used in 'multitaper' mode. Ignored if ``data`` is an - `~mne.time_frequency.EpochsSpectrum` object. + :class:`~mne.time_frequency.EpochsSpectrum` object. mt_low_bias : bool Only use tapers with more than 90 percent spectral concentration within bandwidth. Only used in 'multitaper' mode. Ignored if ``data`` is an - `~mne.time_frequency.EpochsSpectrum` object. + :class:`~mne.time_frequency.EpochsSpectrum` object. cwt_freqs : array Array of frequencies of interest. Only used in 'cwt_morlet' mode. Ignored if - ``data`` is an `~mne.time_frequency.EpochsSpectrum` object. + ``data`` is an :class:`~mne.time_frequency.EpochsSpectrum` object. cwt_n_cycles : float | array of float Number of cycles. Fixed number or one per frequency. Only used in 'cwt_morlet' - mode. Ignored if ``data`` is an `~mne.time_frequency.EpochsSpectrum` object. + mode. Ignored if ``data`` is an :class:`~mne.time_frequency.EpochsSpectrum` + object. gc_n_lags : int Number of lags to use for the vector autoregressive model when computing Granger causality. Higher values increase computational cost, From 121b40a76a8ac1d21bee93a2f7db04e3513b4e75 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 1 Aug 2024 19:51:42 +0200 Subject: [PATCH 07/18] Change empty weights contruction --- mne_connectivity/spectral/epochs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index c252d359a..4c475a900 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -406,8 +406,8 @@ def _compute_spectra( if mode == "multitaper": weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis] else: - # hack to so we can sum over axis=-2 - weights = np.array([1.0])[:, None, None] + # hack to so we can sum over axis=-2 (tapers dim) + weights = np.ones((1, 1, 1)) if accumulate_psd: _this_psd = _psd_from_mt(this_x_t, weights) @@ -516,7 +516,7 @@ def _epoch_spectral_connectivity( x_t = np.array(data)[:, sig_idx, ..., freq_mask] if weights is None: # also assumes no tapers dim x_t = np.expand_dims(x_t, axis=2) # CSD construction expects a tapers dim - weights = np.array([1.0])[:, None, None] + weights = np.ones((1, 1, 1)) if accumulate_psd: this_psd = _psd_from_mt(x_t, weights) else: # compute spectral info from scratch From ffb256d6e975beac282891e8333d5e59ab4fb043 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 2 Aug 2024 14:39:35 +0200 Subject: [PATCH 08/18] Add surrogate data generation --- doc/api.rst | 3 +- doc/references.bib | 21 ++ examples/surrogate_connectivity.py | 351 ++++++++++++++++++ mne_connectivity/__init__.py | 2 +- mne_connectivity/datasets/__init__.py | 1 + mne_connectivity/datasets/surrogate.py | 141 +++++++ .../datasets/tests/test_datasets.py | 316 ++++++++++++++++ mne_connectivity/spectral/epochs.py | 2 +- .../spectral/epochs_multivariate.py | 10 + mne_connectivity/tests/test_datasets.py | 169 --------- 10 files changed, 844 insertions(+), 172 deletions(-) create mode 100644 examples/surrogate_connectivity.py create mode 100644 mne_connectivity/datasets/surrogate.py create mode 100644 mne_connectivity/datasets/tests/test_datasets.py delete mode 100644 mne_connectivity/tests/test_datasets.py diff --git a/doc/api.rst b/doc/api.rst index 81c844187..1a84a4ad9 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -111,4 +111,5 @@ Dataset functions .. autosummary:: :toctree: generated/ - make_signals_in_freq_bands \ No newline at end of file + make_signals_in_freq_bands + make_surrogate_data \ No newline at end of file diff --git a/doc/references.bib b/doc/references.bib index 536bfad39..55cc7e338 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -61,6 +61,16 @@ @article{Dawson_2016 year = {2016} } +@article{DowdingHaufe2018, + title={Powerful statistical inference for nested data using sufficient summary statistics}, + author={Dowding, Irene and Haufe, Stefan}, + doi={10.3389/fnhum.2018.00103}, + journal={Frontiers in Human Neuroscience}, + volume={12}, + pages={103}, + year={2018} +} + @article{EwaldEtAl2012, author = {Ewald, Arne and Marzetti, Laura and Zappasodi, Filippo and Meinecke, Frank C. and Nolte, Guido}, doi = {10.1016/j.neuroimage.2011.11.084}, @@ -185,6 +195,17 @@ @book{OppenheimEtAl1999 year = {1999} } +@article{PellegriniEtAl2023, + title={Identifying good practices for detecting inter-regional linear functional connectivity from {EEG}}, + author={Pellegrini, Franziska and Delorme, Arnaud and Nikulin, Vadim and Haufe, Stefan}, + doi={10.1016/j.neuroimage.2023.120218}, + journal={NeuroImage}, + volume={277}, + pages={120218}, + year={2023}, + publisher={Elsevier} +} + @book{SekiharaNagarajan2008, author = {Sekihara, Kensuke and Nagarajan, Srikantan S.}, doi = {10.1007/978-3-540-79370-0}, diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py new file mode 100644 index 000000000..ec726b80e --- /dev/null +++ b/examples/surrogate_connectivity.py @@ -0,0 +1,351 @@ +""" +================================================================================== +Determine the significance of connectivity estimates against baseline connectivity +================================================================================== + +This example demonstrates how surrogate data can be generated to assess whether +connectivity estimates are significantly greater than baseline. +""" + +# Author: Thomas S. Binns +# License: BSD (3-clause) +# sphinx_gallery_thumbnail_number = 3 + +# %% + +import matplotlib.pyplot as plt +import mne +import numpy as np +from mne.datasets import somato + +from mne_connectivity import make_surrogate_data, spectral_connectivity_epochs + +######################################################################################## +# Background +# ---------- +# +# When performing connectivity analyses, we often want to know whether the results we +# observe reflect genuine interactions between signals. We can assess this by performing +# statistical tests between our connectivity estimates and a 'baseline' level of +# connectivity. However, due to factors such as background noise and sample +# size-dependent biases (see e.g. :footcite:`VinckEtAl2010`), it is often not +# appropriate to treat 0 as this baseline. Therefore, we need a way to estimate the +# baseline level of connectivity. +# +# One approach is to manipulate the original data in such a way that the covariance +# structure is destroyed, creating surrogate data. Connectivity estimates from the +# original and surrogate data can then be compared to determine whether the original +# data contains significant interactions. +# +# Such surrogate data can be easily generated in MNE using the +# :func:`~mne_connectivity.make_surrogate_data` function, which shuffles epoched data +# independently across channels :footcite:`PellegriniEtAl2023` (see the Notes section of +# the function for more information). In this example, we will demonstrate how surrogate +# data can be created, and how you can use this to assess the statistical significance +# of your connectivity estimates. + +######################################################################################## +# Loading the data +# ---------------- +# +# We start by loading from the :ref:`somato-dataset` dataset, MEG data showing +# event-related activity in response to somatosensory stimuli. We construct epochs +# around these events in the time window [-1.5, 1.0] seconds. + +# %% + +# Load data +data_path = somato.data_path() +raw_fname = data_path / "sub-01" / "meg" / "sub-01_task-somato_meg.fif" +raw = mne.io.read_raw_fif(raw_fname) +events = mne.find_events(raw, stim_channel="STI 014") + +# Pre-processing +raw.pick("grad").load_data() # focus on gradiometers +raw.filter(1, 35) +raw, events = raw.resample(sfreq=100, events=events) # reduce compute time + +# Construct epochs around events +epochs = mne.Epochs( + raw, events, event_id=1, tmin=-1.5, tmax=1.0, baseline=(-0.5, 0), preload=True +) +epochs = epochs[:30] # select a subset of epochs to speed up computation + +######################################################################################## +# Assessing connectivity in non-evoked data +# ----------------------------------------- +# +# We will first demonstrate how connectivity can be assessed from non-evoked data. In +# this example, we use data from the pre-trial period of [-1.5, -0.5] seconds. We +# compute Fourier coefficients of the data using the :meth:`~mne.Epochs.compute_psd` +# method with ``output="complex"`` (note that this requires ``mne >= 1.8``). +# +# Next, we pass these coefficients to +# :func:`~mne_connectivity.spectral_connectivity_epochs` to compute connectivity using +# the imaginary part of coherency (``imcoh``). Our indices specify that connectivity +# should be computed between all pairs of channels. + +# %% + +# Compute Fourier coefficients for pre-trial data +fmin, fmax = 3, 23 +pretrial_coeffs = epochs.compute_psd( + fmin=fmin, fmax=fmax, tmin=None, tmax=-0.5, output="complex" +) +freqs = pretrial_coeffs.freqs + +# Compute connectivity for pre-trial data +indices = np.tril_indices(epochs.info["nchan"], k=-1) # all-to-all connectivity +pretrial_con = spectral_connectivity_epochs( + pretrial_coeffs, method="imcoh", indices=indices +) + +######################################################################################## +# Next, we generate the surrogate data by passing the Fourier coefficients into the +# :func:`~mne_connectivity.make_surrogate_data` function. To get a reliable estimate of +# the baseline connectivity, we perform this shuffling procedure +# :math:`\text{n}_{\text{shuffle}}` times, producing :math:`\text{n}_{\text{shuffle}}` +# surrogate datasets. We can then iterate over these shuffles and compute the +# connectivity for each one. + +# %% + +# Generate surrogate data +n_shuffles = 100 # recommended is >= 1,000; limited here to reduce compute time +pretrial_surrogates = make_surrogate_data( + pretrial_coeffs, n_shuffles=n_shuffles, rng_seed=44 +) + +# Compute connectivity for surrogate data +surrogate_con = [] +for shuffle_i, surrogate in enumerate(pretrial_surrogates): + print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") + surrogate_con.append( + spectral_connectivity_epochs( + surrogate, method="imcoh", indices=indices, n_jobs=-1, verbose=False + ) + ) + +######################################################################################## +# We can plot the all-to-all connectivity of the pre-trial data against the surrogate +# data, averaged over all shuffles. This shows a strong degree of coupling in the alpha +# band (~8-12 Hz), with weaker coupling in the lower range of the beta band (~13-20 Hz). +# A simple visual inspection shows that connectivity in the alpha and beta bands are +# above the baseline level of connectivity estimated from the surrogate data. However, +# we need to confirm this statistically. + +# %% + +# Plot pre-trial vs. surrogate connectivity +fig, ax = plt.subplots(1, 1) +ax.plot( + freqs, + np.abs([surrogate.get_data() for surrogate in surrogate_con]).mean(axis=(0, 1)), + linestyle="--", + label="Surrogate", +) +ax.plot(freqs, np.abs(pretrial_con.get_data()).mean(axis=0), label="Original") +ax.set_xlabel("Frequency (Hz)") +ax.set_ylabel("Connectivity (A.U.)") +ax.set_title("All-to-all connectivity | Pre-trial ") +ax.legend() + +######################################################################################## +# Assessing the statistical significance of our connectivity estimates can be done with +# the following simple procedure :footcite:`PellegriniEtAl2023` +# +# :math:`p=\LARGE{\frac{\Sigma_{s=1}^Sc_s}{S}}` , +# +# :math:`c_s=\{1\text{ if }\text{Con}\leq\text{Con}_{\text{s}}\text{ },\text{ }0 +# \text{ if otherwise }` , +# +# where: :math:`p` is our p-value; :math:`s` is a given shuffle iteration of :math:`S` +# total shuffles; and :math:`c` is a binary indicator of whether the true connectivity, +# :math:`\text{Con}`, is greater than the surrogate connectivity, +# :math:`\text{Con}_{\text{s}}`, for a given shuffle. +# +# Note that for connectivity methods which produce negative scores (e.g., imaginary part +# of coherency, time-reversed Granger causality, etc...), you should take the absolute +# values before testing. Similar adjustments should be made for methods that produce +# scores centred around non-zero values (e.g., 0.5 for directed phase lag index). +# +# Below, we determine the statistical significance of connectivity in the lower beta +# band using an alpha of 0.05. Naturally, any tests involving multiple connections, +# frequencies, and/or times should be corrected for multiple comparisons. Here however, +# we average over all connections and frequencies. +# +# The test confirms our visual inspection, showing that connectivity in the lower beta +# band is significantly above the baseline level of connectivity, which we can take as +# evidence of genuine interactions in this frequency band. + +# %% + +# Find indices of lower beta frequencies +beta_freqs = np.where((freqs >= 13) & (freqs <= 20))[0] + +# Compute lower beta connectivity for pre-trial data (average connections and freqs) +beta_con_pretrial = np.abs(pretrial_con.get_data()[:, beta_freqs]).mean(axis=(0, 1)) + +# Compute lower beta connectivity for surrogate data (average connections and freqs) +beta_con_surrogate = np.abs( + [surrogate.get_data()[:, beta_freqs] for surrogate in surrogate_con] +).mean(axis=(1, 2)) + +# Compute p-value for pre-trial lower beta coupling +alpha = 0.05 +p_val = np.sum(beta_con_pretrial <= beta_con_surrogate) / n_shuffles +print(f"P < {alpha}") if p_val < alpha else print(f"P > {alpha}") + +######################################################################################## +# Assessing connectivity in evoked data +# ------------------------------------- +# +# When generating surrogate data, it is important to distinguish non-evoked data (e.g., +# resting-state, pre/inter-trial data) from evoked data (where a stimulus is presented +# or an action performed at a set time during each epoch). Critically, evoked data +# contains a temporal structure that is consistent across epochs, and thus shuffling +# epochs across channels will fail to adequately disrupt the covariance structure. +# +# Any connectivity estimates will therefore overestimate the baseline connectivity in +# your data, increasing the likelihood of type II errors (see the Notes section of +# :func:`~mne_connectivity.make_surrogate_data` for more information, and see the final +# section of this example for a demonstration). +# +# **In cases where you want to assess connectivity in evoked data, you can use +# surrogates generated from non-evoked data (of the same subject).** Here we do just +# that, comparing connectivity estimates from the pre-trial surrogates to the evoked, +# post-stimulus response ([0, 1] second). +# +# Again, there is pronounced alpha coupling (stronger than in the pre-trial data) and +# weaker beta coupling, both of which appear to be above the baseline level of +# connectivity. + +# %% + +# Compute Fourier coefficients for post-stimulus data +poststim_coeffs = epochs.compute_psd( + fmin=fmin, fmax=fmax, tmin=0, tmax=None, output="complex" +) + +# Compute connectivity for post-stimulus data +poststim_con = spectral_connectivity_epochs( + poststim_coeffs, method="imcoh", indices=indices +) + +# Plot post-stimulus vs. (pre-trial) surrogate connectivity +fig, ax = plt.subplots(1, 1) +ax.plot( + freqs, + np.abs([surrogate.get_data() for surrogate in surrogate_con]).mean(axis=(0, 1)), + linestyle="--", + label="Surrogate", +) +ax.plot(freqs, np.abs(poststim_con.get_data()).mean(axis=0), label="Original") +ax.set_xlabel("Frequency (Hz)") +ax.set_ylabel("Connectivity (A.U.)") +ax.set_title("All-to-all connectivity | Post-stimulus") +ax.legend() + +######################################################################################## +# This is also confirmed by statistical testing, with connectivity in the lower beta +# band being significantly above the baseline level of connectivity. Thus, using +# surrogate connectivity estimates from non-evoked data provides a reliable baseline for +# assessing connectivity in evoked data. + +# %% + +# Compute lower beta connectivity for post-stimulus data (average connections and freqs) +beta_con_poststim = np.abs(poststim_con.get_data()[:, beta_freqs]).mean(axis=(0, 1)) + +# Compute p-value for post-stimulus lower beta coupling +p_val = np.sum(beta_con_poststim <= beta_con_surrogate) / n_shuffles +print(f"P < {alpha}") if p_val < alpha else print(f"P > {alpha}") + +######################################################################################## +# Generating surrogate connectivity from inappropriate data +# --------------------------------------------------------- +# +# We discussed above how surrogates generated from evoked data risk overestimating the +# degree of baseline connectivity. We demonstrate this below by generating surrogates +# from the post-stimulus data. + +# %% + +# Generate surrogates from evoked data +poststim_surrogates = make_surrogate_data( + poststim_coeffs, n_shuffles=n_shuffles, rng_seed=44 +) + +# Compute connectivity for evoked surrogate data +bad_surrogate_con = [] +for shuffle_i, surrogate in enumerate(poststim_surrogates): + print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") + bad_surrogate_con.append( + spectral_connectivity_epochs( + surrogate, method="imcoh", indices=indices, n_jobs=-1, verbose=False + ) + ) + +######################################################################################## +# Plotting the post-stimulus connectivity against the estimates from the non-evoked and +# evoked surrogate data, we see that the evoked surrogate data greatly overestimates the +# baseline connectivity in the alpha band. +# +# Although in this case the alpha connectivity was still far above the baseline from the +# evoked surrogates, this will not always be the case, and you can see how this risks +# false negative assessments that connectivity is not significantly different from +# baseline. + +# %% + +# Plot post-stimulus vs. evoked and non-evoked surrogate connectivity +fig, ax = plt.subplots(1, 1) +ax.plot( + freqs, + np.abs([surrogate.get_data() for surrogate in surrogate_con]).mean(axis=(0, 1)), + linestyle="--", + label="Surrogate (pre-stimulus)", +) +ax.plot( + freqs, + np.abs([surrogate.get_data() for surrogate in bad_surrogate_con]).mean(axis=(0, 1)), + color="C3", + linestyle="--", + label="Surrogate (post-stimulus)", +) +ax.plot( + freqs, np.abs(poststim_con.get_data()).mean(axis=0), color="C1", label="Original" +) +ax.set_xlabel("Frequency (Hz)") +ax.set_ylabel("Connectivity (A.U.)") +ax.set_title("All-to-all connectivity | Post-stimulus") +ax.legend() + +######################################################################################## +# Assessing connectivity on a group-level +# --------------------------------------- +# +# While our focus here has been on assessing the significance of connectivity on a +# single recording-level, we may also want to determine whether group-level connectivity +# estimates are significantly different from baseline. For this, we can generate +# surrogates and estimate connectivity alongside the original signals for each piece of +# data. +# +# There are multiple ways to assess the statistical significance. For example, we can +# compute p-values for each piece of data using the approach above and combine them for +# the nested data (e.g., across recordings, subjects, etc...) using Stouffer's method +# :footcite:`DowdingHaufe2018`. +# +# Alternatively, we could take the average of the surrogate connectivity estimates +# across all shuffles for each piece of data and compare them to the original +# connectivity estimates in a paired test. The :mod:`scipy.stats` and :mod:`mne.stats` +# modules have many such tools for testing this, e.g., :func:`scipy.stats.ttest_1samp`, +# :func:`mne.stats.permutation_t_test`, etc... +# +# Therefore, surrogate connectivity estimates are a powerful tool for assessing the +# significance of connectivity estimates, both on a single recording- and group-level. + +######################################################################################## +# References +# ---------- +# .. footbibliography:: diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index ce18a2849..649c28eae 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -23,7 +23,7 @@ SpectroTemporalConnectivity, TemporalConnectivity, ) -from .datasets import make_signals_in_freq_bands +from .datasets import make_signals_in_freq_bands, make_surrogate_data from .decoding import CoherencyDecomposition from .effective import phase_slope_index from .envelope import envelope_correlation, symmetric_orth diff --git a/mne_connectivity/datasets/__init__.py b/mne_connectivity/datasets/__init__.py index d5c8e2eb8..dc2296ee0 100644 --- a/mne_connectivity/datasets/__init__.py +++ b/mne_connectivity/datasets/__init__.py @@ -1 +1,2 @@ from .frequency import make_signals_in_freq_bands +from .surrogate import make_surrogate_data diff --git a/mne_connectivity/datasets/surrogate.py b/mne_connectivity/datasets/surrogate.py new file mode 100644 index 000000000..a482b5c60 --- /dev/null +++ b/mne_connectivity/datasets/surrogate.py @@ -0,0 +1,141 @@ +# Authors: Thomas S. Binns +# +# License: BSD (3-clause) + +import numpy as np +from mne.time_frequency import EpochsSpectrum, EpochsSpectrumArray +from mne.utils import _validate_type + + +def make_surrogate_data(data, n_shuffles=1000, rng_seed=None, return_generator=True): + """Create surrogate data for a null hypothesis of connectivity. + + Parameters + ---------- + data : ~mne.time_frequency.EpochsSpectrum | ~mne.time_frequency.EpochsSpectrumArray + The Fourier coefficients to create the null hypothesis surrogate data for. Can + be generated from :meth:`mne.Epochs.compute_psd` with ``output='complex'`` + (requires ``mne >= 1.8``). + n_shuffles : int (default 1000) + The number of surrogate datasets to create. + rng_seed : int | None (default None) + The seed to use for the random number generator. If `None`, no seed is + specified. + return_generator : bool (default True) + Whether or not to return the surrogate data as a :term:`generator` object + instead of a :class:`list`. This allows iterating over the surrogates without + having to keep them all in memory. + + Returns + ------- + surrogate_data : list of ~mne.time_frequency.EpochsSpectrum + The surrogate data for the null hypothesis with ``n_shuffles`` entries. Returned + as a :term:`generator` if ``return_generator=True``. + + Notes + ----- + Surrogate data is generated by randomly shuffling the order of epochs, independently + for each channel. This destroys the covariance of the data, such that connectivity + estimates should reflect the null hypothesis of no genuine connectivity between + signals (e.g., only interactions due to background noise) + :footcite:`PellegriniEtAl2023`. + + For the surrogate data to properly reflect a null hypothesis, the data which is + shuffled **must not** have a temporal structure that is consistent across epochs. + Examples of this data include evoked potentials, where a stimulus is presented or an + action performed at a set time during each epoch. Such data should not be used for + generating surrogates, as even after shuffling the epochs, it will still show a high + degree of residual connectivity between channels. As a result, connectivity + estimates from your surrogate data will capture genuine interactions, instead of the + desired background noise. Treating these estimates as a null hypothesis will + increase the likelihood of a type II (false negative) error, i.e., that there is no + significant connectivity in your data. + + Appropriate data for generating surrogates includes data from a resting state, + inter-trial period, or similar. Here, a strong temporal consistency across epochs is + not assumed, reducing the chances that connectivity information of interest is + captured in your surrogate connectivity estimates. + + In situations where you want to assess whether evoked data has significant + connectivity, you can generate your surrogate connectivity estimates from non-evoked + data (e.g., rest data, inter-trial data) and compare this to your true connectivity + estimates from the evoked data. + + Regardless of whether you are working with evoked or non-evoked data, **you should + always compare true and surrogate connectivity estimates from epochs of the same + duration**. This will ensure that spectral information is captured with the same + accuracy in both sets of connectivity estimates. Ideally, **you should also compare + true and surrogate connectivity estimates from the same number of epochs** to avoid + biases from noise (fewer epochs gives noisier estimates) or finite sample sizes + (e.g., in coherency, phase-locking value, etc... :footcite:`VinckEtAl2010`). + + .. versionadded:: 0.8 + + References + ---------- + .. footbibliography:: + """ + # Validate inputs + _validate_type( + data, + (EpochsSpectrum, EpochsSpectrumArray), + "data", + "mne.time_frequency.EpochsSpectrum or mne.time_frequency.EpochsSpectrumArray", + ) + if not np.iscomplexobj(data.get_data()): + raise TypeError("values in `data` must be complex-valued") + n_epochs, n_chans = data.get_data().shape[:2] + if n_epochs == 1: + raise ValueError("data must contain more than one epoch for shuffling") + if n_chans == 1: + raise ValueError("data must contain more than one channel for shuffling") + + _validate_type(n_shuffles, "int-like", "n_shuffles", "int") + if n_shuffles < 1: + raise ValueError("number of shuffles must be >= 1") + + _validate_type(return_generator, bool, "return_generator", "bool") + # rng_seed checked by NumPy later + + # Make surrogate data and package into EpochsSpectrum objects + surrogate_data = _shuffle_coefficients(data, n_shuffles, rng_seed) + if not return_generator: + surrogate_data = [shuffle for shuffle in surrogate_data] + + return surrogate_data + + +def _shuffle_coefficients(data, n_shuffles, rng_seed): + """Shuffle coefficients over epochs to create surrogate data. + + Surrogate data for each shuffle packaged into an EpochsSpectrum object, which are + together returned as a generator to minimise memory demand. + """ + # Extract data array and EpochsSpectrum information + data_arr = data.get_data() + state = data.__getstate__() + defaults = dict( + method=None, + fmin=None, + fmax=None, + tmin=None, + tmax=None, + picks=None, + exclude=(), + proj=None, + remove_dc=None, + n_jobs=None, + verbose=None, + ) + + # Make surrogate data + rng = np.random.default_rng(rng_seed) + for _ in range(n_shuffles): + # Shuffle epochs for each channel independently + surrogate_arr = np.zeros_like(data_arr, dtype=data_arr.dtype) + for chan_i in range(data_arr.shape[1]): + surrogate_arr[:, chan_i] = rng.permutation(data_arr[:, chan_i], axis=0) + + # Package surrogate data for this shuffle + state["data"] = surrogate_arr + yield EpochsSpectrum(state, **defaults) # return surrogate data as a generator diff --git a/mne_connectivity/datasets/tests/test_datasets.py b/mne_connectivity/datasets/tests/test_datasets.py new file mode 100644 index 000000000..70edccb54 --- /dev/null +++ b/mne_connectivity/datasets/tests/test_datasets.py @@ -0,0 +1,316 @@ +from collections.abc import Generator + +import numpy as np +import pytest +from mne import create_info +from mne.time_frequency import EpochsSpectrumArray + +from mne_connectivity import ( + make_signals_in_freq_bands, + make_surrogate_data, + seed_target_indices, + spectral_connectivity_epochs, +) + + +@pytest.mark.parametrize("n_seeds", [1, 3]) +@pytest.mark.parametrize("n_targets", [1, 3]) +@pytest.mark.parametrize("snr", [0.7, 0.4]) +@pytest.mark.parametrize("connection_delay", [0, 3, -3]) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) +def test_make_signals_in_freq_bands(n_seeds, n_targets, snr, connection_delay, mode): + """Test `make_signals_in_freq_bands` simulates connectivity properly.""" + # Case with no spurious correlations (avoids tests randomly failing) + rng_seed = 0 + + # Simulate data + freq_band = (5, 10) # fmin, fmax (Hz) + sfreq = 100 # Hz + trans_bandwidth = 1 # Hz + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_band, + n_epochs=30, + n_times=200, + sfreq=sfreq, + trans_bandwidth=trans_bandwidth, + snr=snr, + connection_delay=connection_delay, + rng_seed=rng_seed, + ) + + # Compute connectivity + methods = ["coh", "imcoh", "dpli"] + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + fmin = 3 + fmax = sfreq // 2 + if mode == "cwt_morlet": + cwt_params = {"cwt_freqs": np.arange(fmin, fmax), "cwt_n_cycles": 3.5} + else: + cwt_params = dict() + con = spectral_connectivity_epochs( + data, + method=methods, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + **cwt_params, + ) + freqs = np.array(con[0].freqs) + + # Define expected connectivity values + thresh_good = dict() + thresh_bad = dict() + # Coh + thresh_good["coh"] = (0.2, 0.9) + thresh_bad["coh"] = (0.0, 0.2) + # ImCoh + if connection_delay == 0: + thresh_good["imcoh"] = (0.0, 0.17) + thresh_bad["imcoh"] = (0.0, 0.17) + else: + thresh_good["imcoh"] = (0.17, 0.8) + thresh_bad["imcoh"] = (0.0, 0.17) + # DPLI + if connection_delay == 0: + thresh_good["dpli"] = (0.3, 0.6) + thresh_bad["dpli"] = (0.3, 0.6) + elif connection_delay > 0: + thresh_good["dpli"] = (0.5, 1) + thresh_bad["dpli"] = (0.3, 0.6) + else: + thresh_good["dpli"] = (0, 0.5) + thresh_bad["dpli"] = (0.3, 0.6) + + # Check connectivity values are acceptable + freqs_good = np.argwhere( + (freqs >= freq_band[0]) & (freqs <= freq_band[1]) + ).flatten() + freqs_bad = np.argwhere( + (freqs < freq_band[0] - trans_bandwidth * 2) + | (freqs > freq_band[1] + trans_bandwidth * 2) + ).flatten() + for method_name, method_con in zip(methods, con): + con_values = method_con.get_data() + if method_name == "imcoh": + con_values = np.abs(con_values) + # freq. band of interest + con_values_good = np.mean(con_values[:, freqs_good]) + assert ( + con_values_good >= thresh_good[method_name][0] + and con_values_good <= thresh_good[method_name][1] + ) + + # other freqs. + con_values_bad = np.mean(con_values[:, freqs_bad]) + assert ( + con_values_bad >= thresh_bad[method_name][0] + and con_values_bad <= thresh_bad[method_name][1] + ) + + +def test_make_signals_in_freq_bands_error_catch(): + """Test error catching for `make_signals_in_freq_bands`.""" + freq_band = (5, 10) + + # check bad n_seeds/targets caught + with pytest.raises( + ValueError, match="Number of seeds and targets must each be at least 1." + ): + make_signals_in_freq_bands(n_seeds=0, n_targets=1, freq_band=freq_band) + with pytest.raises( + ValueError, match="Number of seeds and targets must each be at least 1." + ): + make_signals_in_freq_bands(n_seeds=1, n_targets=0, freq_band=freq_band) + + # check bad freq_band caught + with pytest.raises(TypeError, match="Frequency band must be a tuple."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=1) + with pytest.raises(ValueError, match="Frequency band must contain two numbers."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=(1, 2, 3)) + + # check bad n_times + with pytest.raises(ValueError, match="Number of timepoints must be at least 1."): + make_signals_in_freq_bands( + n_seeds=1, n_targets=1, freq_band=freq_band, n_times=0 + ) + + # check bad n_epochs + with pytest.raises(ValueError, match="Number of epochs must be at least 1."): + make_signals_in_freq_bands( + n_seeds=1, n_targets=1, freq_band=freq_band, n_epochs=0 + ) + + # check bad sfreq + with pytest.raises(ValueError, match="Sampling frequency must be > 0."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, sfreq=0) + + # check bad snr + with pytest.raises( + ValueError, match="Signal-to-noise ratio must be between 0 and 1." + ): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=-1) + with pytest.raises( + ValueError, match="Signal-to-noise ratio must be between 0 and 1." + ): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=2) + + # check bad connection_delay + with pytest.raises( + ValueError, + match="Connection delay must be less than the total number of timepoints.", + ): + make_signals_in_freq_bands( + n_seeds=1, + n_targets=1, + freq_band=freq_band, + n_epochs=1, + n_times=1, + connection_delay=1, + ) + + +@pytest.mark.parametrize(("snr", "should_be_significant"), ([0.3, True], [0.1, False])) +@pytest.mark.parametrize("mode", ["multitaper", "fourier"]) +def test_make_surrogate_data(snr, should_be_significant, mode): + """Test `make_surrogate_data` creates data for null hypothesis testing.""" + # Generate data + n_seeds = 2 + n_targets = 2 + freq_band = (10, 15) + n_epochs = 30 + sfreq = 100 + n_times = sfreq * 2 + n_shuffles = 1000 + rng_seed = 44 + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_band, + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, # using very high SNR seems to alter properties of data beyond fband + rng_seed=rng_seed, + ) + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + + # Compute Fourier coefficients and generate surrogates + spectrum = data.compute_psd( + method="welch" if mode == "fourier" else mode, output="complex" + ) + surrogate_spectrum = make_surrogate_data( + data=spectrum, n_shuffles=1000, rng_seed=rng_seed + ) + + # Compute connectivity + con = spectral_connectivity_epochs(data=spectrum, method="coh", indices=indices) + freqs = np.array(con.freqs) + connectivity = np.zeros((n_shuffles + 1, *con.shape)) + connectivity[0] = con.get_data() # first entry is original data + for shuffle_i, shuffle_data in enumerate(surrogate_spectrum): + connectivity[shuffle_i + 1] = spectral_connectivity_epochs( + data=shuffle_data, method="coh", indices=indices, verbose=False + ).get_data() + + # Determine if connectivity significant + alpha = 0.05 + con_freqs = (freqs >= freq_band[0]) & (freqs <= freq_band[1]) + noise_freqs = np.invert(con_freqs) + + pval_con_freqs = ( + np.sum( + np.mean(connectivity[0, :, con_freqs]) # aggregate cons and freqs + <= np.mean(connectivity[1:, :, con_freqs], axis=(1, 2)) # same aggr. here + ) + / n_shuffles + ) + + pval_noise_freqs = ( + np.sum( + np.mean(connectivity[0, :, noise_freqs]) + <= np.mean(connectivity[1:, :, noise_freqs], axis=(1, 2)) + ) + / n_shuffles + ) + + if should_be_significant: + assert pval_con_freqs < alpha, f"pval_con_freqs: {pval_con_freqs}" + else: + assert pval_con_freqs >= alpha, f"pval_con_freqs: {pval_con_freqs}" + + # Freqs where nothing simulated should never be significant + assert pval_noise_freqs > alpha, f"pval_noise_freqs: {pval_noise_freqs}" + + +def test_make_surrogate_data_generator(): + """Test `return_generator` parameter works in `make_surrogate_data`.""" + # Generate random data for packaging into EpochsSpectrum + n_epochs = 5 + n_chans = 6 + n_freqs = 50 + sfreq = n_freqs * 2 + rng = np.random.default_rng(44) + data = rng.random((n_epochs, n_chans, n_freqs)).astype(np.complex128) + data += data * 1j # complex dtypes not supported for simulation, so make complex + info = create_info(ch_names=n_chans, sfreq=sfreq, ch_types="eeg") + spectrum = EpochsSpectrumArray(data=data, info=info, freqs=np.arange(n_freqs)) + + # Test generator (not) returned when requested + surrogate_data = make_surrogate_data(data=spectrum, return_generator=True) + assert isinstance(surrogate_data, Generator), type(surrogate_data) + surrogate_data = make_surrogate_data(data=spectrum, return_generator=False) + assert isinstance(surrogate_data, list), type(surrogate_data) + + +def test_make_surrogate_data_error_catch(): + """Test error catching for `make_surrogate_data`.""" + # Generate random data for packaging into EpochsSpectrum + n_epochs = 5 + n_chans = 6 + n_freqs = 50 + sfreq = n_freqs * 2 + rng = np.random.default_rng(44) + data = rng.random((n_epochs, n_chans, n_freqs)).astype(np.complex128) + data += data * 1j # complex dtypes not supported for simulation, so make complex + info = create_info(ch_names=n_chans, sfreq=sfreq, ch_types="eeg") + spectrum = EpochsSpectrumArray(data=data, info=info, freqs=np.arange(n_freqs)) + + # check bad data + with pytest.raises(TypeError, match=r"data must be an instance of.*EpochsSpectrum"): + make_surrogate_data(data=data) + with pytest.raises(TypeError, match="values in `data` must be complex-valued"): + bad_dtype_data = EpochsSpectrumArray( + data=np.abs(data), info=info, freqs=np.arange(n_freqs) + ) + make_surrogate_data(data=bad_dtype_data) + with pytest.raises(ValueError, match="data must contain more than one epoch"): + bad_nepochs_data = EpochsSpectrumArray( + data=data[[0]], info=info, freqs=np.arange(n_freqs) + ) + make_surrogate_data(data=bad_nepochs_data) + with pytest.raises(ValueError, match="data must contain more than one channel"): + bad_nchans_data = EpochsSpectrumArray( + data=data[:, [0]], + info=create_info(ch_names=1, sfreq=sfreq, ch_types="eeg"), + freqs=np.arange(n_freqs), + ) + make_surrogate_data(data=bad_nchans_data) + + # check bad n_shuffles + with pytest.raises(TypeError, match="n_shuffles must be an instance of int"): + make_surrogate_data(data=spectrum, n_shuffles="all") + with pytest.raises(ValueError, match="number of shuffles must be >= 1"): + make_surrogate_data(data=spectrum, n_shuffles=0) + with pytest.raises(ValueError, match="number of shuffles must be >= 1"): + make_surrogate_data(data=spectrum, n_shuffles=-1) + + # check bad return_generator + with pytest.raises(TypeError, match="return_generator must be an instance of bool"): + make_surrogate_data(data=spectrum, return_generator="yes") diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 4c475a900..13156d636 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -513,7 +513,7 @@ def _epoch_spectral_connectivity( if spectrum_computed: # use existing spectral info # XXX: Will need to distinguish time-resolved spectra here if support added # Select signals & freqs of interest (flexible indexing for optional tapers dim) - x_t = np.array(data)[:, sig_idx, ..., freq_mask] + x_t = np.array(data)[:, sig_idx][..., freq_mask] # split dims to avoid np.ix_ if weights is None: # also assumes no tapers dim x_t = np.expand_dims(x_t, axis=2) # CSD construction expects a tapers dim weights = np.ones((1, 1, 1)) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 57381974f..e16b30033 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -16,6 +16,8 @@ import numpy as np from mne.epochs import BaseEpochs from mne.parallel import parallel_func +from mne.time_frequency import EpochsSpectrum, EpochsSpectrumArray +from mne.time_frequency.multitaper import _psd_from_mt from mne.utils import ProgressBar, _validate_type, logger @@ -31,6 +33,14 @@ def _check_rank_input(rank, data, indices): if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: kwargs["copy"] = False data_arr = data.get_data(**kwargs) + elif isinstance(data, (EpochsSpectrum, EpochsSpectrumArray)): + # Spectrum objs will drop bad channels, so specify picking all channels + data_arr = data.get_data(picks=np.arange(data.info["nchan"])) + # Convert to power (and aggregate over tapers) before computing rank + if "taper" in data._dims: + data_arr = _psd_from_mt(data_arr, data.weights) + else: + data_arr = (data_arr * data_arr.conj()).real else: data_arr = data diff --git a/mne_connectivity/tests/test_datasets.py b/mne_connectivity/tests/test_datasets.py deleted file mode 100644 index 4ae7d5ac2..000000000 --- a/mne_connectivity/tests/test_datasets.py +++ /dev/null @@ -1,169 +0,0 @@ -import numpy as np -import pytest - -from mne_connectivity import ( - make_signals_in_freq_bands, - seed_target_indices, - spectral_connectivity_epochs, -) - - -@pytest.mark.parametrize("n_seeds", [1, 3]) -@pytest.mark.parametrize("n_targets", [1, 3]) -@pytest.mark.parametrize("snr", [0.7, 0.4]) -@pytest.mark.parametrize("connection_delay", [0, 3, -3]) -@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) -def test_make_signals_in_freq_bands(n_seeds, n_targets, snr, connection_delay, mode): - """Test `make_signals_in_freq_bands` simulates connectivity properly.""" - # Case with no spurious correlations (avoids tests randomly failing) - rng_seed = 0 - - # Simulate data - freq_band = (5, 10) # fmin, fmax (Hz) - sfreq = 100 # Hz - trans_bandwidth = 1 # Hz - data = make_signals_in_freq_bands( - n_seeds=n_seeds, - n_targets=n_targets, - freq_band=freq_band, - n_epochs=30, - n_times=200, - sfreq=sfreq, - trans_bandwidth=trans_bandwidth, - snr=snr, - connection_delay=connection_delay, - rng_seed=rng_seed, - ) - - # Compute connectivity - methods = ["coh", "imcoh", "dpli"] - indices = seed_target_indices( - seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds - ) - fmin = 3 - fmax = sfreq // 2 - if mode == "cwt_morlet": - cwt_params = {"cwt_freqs": np.arange(fmin, fmax), "cwt_n_cycles": 3.5} - else: - cwt_params = dict() - con = spectral_connectivity_epochs( - data, - method=methods, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - **cwt_params, - ) - freqs = np.array(con[0].freqs) - - # Define expected connectivity values - thresh_good = dict() - thresh_bad = dict() - # Coh - thresh_good["coh"] = (0.2, 0.9) - thresh_bad["coh"] = (0.0, 0.2) - # ImCoh - if connection_delay == 0: - thresh_good["imcoh"] = (0.0, 0.17) - thresh_bad["imcoh"] = (0.0, 0.17) - else: - thresh_good["imcoh"] = (0.17, 0.8) - thresh_bad["imcoh"] = (0.0, 0.17) - # DPLI - if connection_delay == 0: - thresh_good["dpli"] = (0.3, 0.6) - thresh_bad["dpli"] = (0.3, 0.6) - elif connection_delay > 0: - thresh_good["dpli"] = (0.5, 1) - thresh_bad["dpli"] = (0.3, 0.6) - else: - thresh_good["dpli"] = (0, 0.5) - thresh_bad["dpli"] = (0.3, 0.6) - - # Check connectivity values are acceptable - freqs_good = np.argwhere( - (freqs >= freq_band[0]) & (freqs <= freq_band[1]) - ).flatten() - freqs_bad = np.argwhere( - (freqs < freq_band[0] - trans_bandwidth * 2) - | (freqs > freq_band[1] + trans_bandwidth * 2) - ).flatten() - for method_name, method_con in zip(methods, con): - con_values = method_con.get_data() - if method_name == "imcoh": - con_values = np.abs(con_values) - # freq. band of interest - con_values_good = np.mean(con_values[:, freqs_good]) - assert ( - con_values_good >= thresh_good[method_name][0] - and con_values_good <= thresh_good[method_name][1] - ) - - # other freqs. - con_values_bad = np.mean(con_values[:, freqs_bad]) - assert ( - con_values_bad >= thresh_bad[method_name][0] - and con_values_bad <= thresh_bad[method_name][1] - ) - - -def test_make_signals_error_catch(): - """Test error catching for `make_signals_in_freq_bands`.""" - freq_band = (5, 10) - - # check bad n_seeds/targets caught - with pytest.raises( - ValueError, match="Number of seeds and targets must each be at least 1." - ): - make_signals_in_freq_bands(n_seeds=0, n_targets=1, freq_band=freq_band) - with pytest.raises( - ValueError, match="Number of seeds and targets must each be at least 1." - ): - make_signals_in_freq_bands(n_seeds=1, n_targets=0, freq_band=freq_band) - - # check bad freq_band caught - with pytest.raises(TypeError, match="Frequency band must be a tuple."): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=1) - with pytest.raises(ValueError, match="Frequency band must contain two numbers."): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=(1, 2, 3)) - - # check bad n_times - with pytest.raises(ValueError, match="Number of timepoints must be at least 1."): - make_signals_in_freq_bands( - n_seeds=1, n_targets=1, freq_band=freq_band, n_times=0 - ) - - # check bad n_epochs - with pytest.raises(ValueError, match="Number of epochs must be at least 1."): - make_signals_in_freq_bands( - n_seeds=1, n_targets=1, freq_band=freq_band, n_epochs=0 - ) - - # check bad sfreq - with pytest.raises(ValueError, match="Sampling frequency must be > 0."): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, sfreq=0) - - # check bad snr - with pytest.raises( - ValueError, match="Signal-to-noise ratio must be between 0 and 1." - ): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=-1) - with pytest.raises( - ValueError, match="Signal-to-noise ratio must be between 0 and 1." - ): - make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=2) - - # check bad connection_delay - with pytest.raises( - ValueError, - match="Connection delay must be less than the total number of timepoints.", - ): - make_signals_in_freq_bands( - n_seeds=1, - n_targets=1, - freq_band=freq_band, - n_epochs=1, - n_times=1, - connection_delay=1, - ) From 30dc4e7c0913d872552f40877adbc3f387fa342a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 6 Aug 2024 12:59:58 +0200 Subject: [PATCH 09/18] Adjust n_jobs --- examples/surrogate_connectivity.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py index ec726b80e..19f019bfb 100644 --- a/examples/surrogate_connectivity.py +++ b/examples/surrogate_connectivity.py @@ -13,6 +13,8 @@ # %% +from multiprocessing import cpu_count + import matplotlib.pyplot as plt import mne import numpy as np @@ -20,6 +22,8 @@ from mne_connectivity import make_surrogate_data, spectral_connectivity_epochs +n_jobs = cpu_count() // 2 + ######################################################################################## # Background # ---------- @@ -122,7 +126,7 @@ print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") surrogate_con.append( spectral_connectivity_epochs( - surrogate, method="imcoh", indices=indices, n_jobs=-1, verbose=False + surrogate, method="imcoh", indices=indices, n_jobs=n_jobs, verbose=False ) ) @@ -282,7 +286,7 @@ print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") bad_surrogate_con.append( spectral_connectivity_epochs( - surrogate, method="imcoh", indices=indices, n_jobs=-1, verbose=False + surrogate, method="imcoh", indices=indices, n_jobs=n_jobs, verbose=False ) ) From 2a9c1e2a472132fe6ba5f0b97c084d794be1e46e Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 6 Aug 2024 13:14:09 +0200 Subject: [PATCH 10/18] Adjust n_jobs --- examples/surrogate_connectivity.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py index 19f019bfb..ca759faa7 100644 --- a/examples/surrogate_connectivity.py +++ b/examples/surrogate_connectivity.py @@ -13,8 +13,6 @@ # %% -from multiprocessing import cpu_count - import matplotlib.pyplot as plt import mne import numpy as np @@ -22,8 +20,6 @@ from mne_connectivity import make_surrogate_data, spectral_connectivity_epochs -n_jobs = cpu_count() // 2 - ######################################################################################## # Background # ---------- @@ -126,7 +122,7 @@ print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") surrogate_con.append( spectral_connectivity_epochs( - surrogate, method="imcoh", indices=indices, n_jobs=n_jobs, verbose=False + surrogate, method="imcoh", indices=indices, verbose=False ) ) @@ -286,7 +282,7 @@ print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") bad_surrogate_con.append( spectral_connectivity_epochs( - surrogate, method="imcoh", indices=indices, n_jobs=n_jobs, verbose=False + surrogate, method="imcoh", indices=indices, verbose=False ) ) From 50ea4cd64e3bcd38f628a449452828f25ef15707 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 6 Aug 2024 13:59:34 +0200 Subject: [PATCH 11/18] Try CPU count // 3 --- examples/surrogate_connectivity.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py index ca759faa7..e2dcc41ad 100644 --- a/examples/surrogate_connectivity.py +++ b/examples/surrogate_connectivity.py @@ -13,6 +13,8 @@ # %% +from multiprocessing import cpu_count + import matplotlib.pyplot as plt import mne import numpy as np @@ -20,6 +22,8 @@ from mne_connectivity import make_surrogate_data, spectral_connectivity_epochs +n_jobs = cpu_count() // 3 + ######################################################################################## # Background # ---------- @@ -122,7 +126,7 @@ print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") surrogate_con.append( spectral_connectivity_epochs( - surrogate, method="imcoh", indices=indices, verbose=False + surrogate, method="imcoh", indices=indices, n_jobs=n_jobs, verbose=False ) ) @@ -282,7 +286,7 @@ print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") bad_surrogate_con.append( spectral_connectivity_epochs( - surrogate, method="imcoh", indices=indices, verbose=False + surrogate, method="imcoh", indices=indices, n_jobs=n_jobs, verbose=False ) ) From 9e4a11d8651e222b2d906f918036c6d673752c2e Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 6 Aug 2024 14:30:49 +0200 Subject: [PATCH 12/18] Try CPU count // 4 --- examples/surrogate_connectivity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py index e2dcc41ad..e3adda3ef 100644 --- a/examples/surrogate_connectivity.py +++ b/examples/surrogate_connectivity.py @@ -22,7 +22,7 @@ from mne_connectivity import make_surrogate_data, spectral_connectivity_epochs -n_jobs = cpu_count() // 3 +n_jobs = cpu_count() // 4 ######################################################################################## # Background From de52286470c532b1a1472e834ce888d119101bc3 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 6 Aug 2024 14:55:20 +0200 Subject: [PATCH 13/18] Use n_jobs=1 --- examples/surrogate_connectivity.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py index e3adda3ef..ca759faa7 100644 --- a/examples/surrogate_connectivity.py +++ b/examples/surrogate_connectivity.py @@ -13,8 +13,6 @@ # %% -from multiprocessing import cpu_count - import matplotlib.pyplot as plt import mne import numpy as np @@ -22,8 +20,6 @@ from mne_connectivity import make_surrogate_data, spectral_connectivity_epochs -n_jobs = cpu_count() // 4 - ######################################################################################## # Background # ---------- @@ -126,7 +122,7 @@ print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") surrogate_con.append( spectral_connectivity_epochs( - surrogate, method="imcoh", indices=indices, n_jobs=n_jobs, verbose=False + surrogate, method="imcoh", indices=indices, verbose=False ) ) @@ -286,7 +282,7 @@ print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") bad_surrogate_con.append( spectral_connectivity_epochs( - surrogate, method="imcoh", indices=indices, n_jobs=n_jobs, verbose=False + surrogate, method="imcoh", indices=indices, verbose=False ) ) From cece6a4b9c12e92c9a955df179351eef330ddbd7 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 7 Aug 2024 12:31:47 +0200 Subject: [PATCH 14/18] Reset tests to upstream-main --- .circleci/config.yml | 2 +- .github/workflows/linux_conda.yml | 2 +- .github/workflows/unit_tests.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f78ff8422..62cffbb18 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -77,7 +77,7 @@ jobs: - run: name: Get Python running and install dependencies command: | - pip install git+https://github.com/tsbinns/mne-python@complex_spectrum + pip install git+https://github.com/mne-tools/mne-python@main curl https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/circleci_dependencies.sh -o circleci_dependencies.sh chmod +x circleci_dependencies.sh ./circleci_dependencies.sh diff --git a/.github/workflows/linux_conda.yml b/.github/workflows/linux_conda.yml index 9b985af66..d7de45605 100644 --- a/.github/workflows/linux_conda.yml +++ b/.github/workflows/linux_conda.yml @@ -41,7 +41,7 @@ jobs: source ./get_minimal_commands.sh pip install .[test] name: 'Install dependencies' - - run: pip install git+https://github.com/tsbinns/mne-python@complex_spectrum + - run: pip install git+https://github.com/mne-tools/mne-python@main - run: pip install -e . - run: | which mne diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 0d591f0ca..803a28991 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -60,7 +60,7 @@ jobs: run: pip install --upgrade mne - name: Install MNE (main) if: matrix.mne-version == 'mne-main' - run: pip install git+https://github.com/tsbinns/mne-python@complex_spectrum + run: pip install git+https://github.com/mne-tools/mne-python@main - run: python -c "import mne; print(mne.datasets.testing.data_path(verbose=True))" if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' - name: Display versions and environment information From 17f30cd54251d8cd3add21ca358ea445cc17516e Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 5 Aug 2024 14:05:01 +0200 Subject: [PATCH 15/18] Expand test coverage --- .../spectral/tests/test_spectral.py | 51 +++++++++++++------ 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 9631b83bc..d13cbea4c 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -473,9 +473,14 @@ def test_spectral_connectivity(method, mode): @pytest.mark.skipif( not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" ) +@pytest.mark.parametrize("method", ["coh", "cacoh"]) @pytest.mark.parametrize("mode", ["multitaper", "fourier"]) -def test_spectral_connectivity_epochs_spectrum_input(mode): - """Test spec_conn_epochs works with EpochsSpectrum data as input.""" +def test_spectral_connectivity_epochs_spectrum_input(method, mode): + """Test spec_conn_epochs works with EpochsSpectrum data as input. + + Important to test both bivariate and multivariate methods, as the latter involves + additional steps (e.g., rank computation). + """ # Simulation parameters & data generation sfreq = 100.0 # Hz n_seeds = 2 @@ -499,28 +504,41 @@ def test_spectral_connectivity_epochs_spectrum_input(mode): rng_seed=44, ) - indices = seed_target_indices( - seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds - ) + if method == "coh": + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + else: + indices = ([np.arange(n_seeds)], [np.arange(n_targets) + n_seeds]) # Compute Fourier coefficients + kwargs = dict() + if mode == "fourier": + kwargs.update(window="hann") # default is Hamming, but we need Hanning coeffs = data.compute_psd( - method="welch" if mode == "fourier" else mode, output="complex" + method="welch" if mode == "fourier" else mode, output="complex", **kwargs ) - # Compute connectivity (just coherence) - con = spectral_connectivity_epochs(data=coeffs, method="coh", indices=indices) + # Compute connectivity + con = spectral_connectivity_epochs(data=coeffs, method=method, indices=indices) # Check connectivity from Epochs and Spectrum are equivalent; # Works for multitaper, but Welch of Spectrum and Fourier of spec_conn are slightly # off (max. abs. diff. ~0.006) even when what should be identical settings are used + con_from_epochs = spectral_connectivity_epochs( + data=data, method=method, indices=indices, mode=mode + ) if mode == "multitaper": - con_from_epochs = spectral_connectivity_epochs( - data=data, method="coh", indices=indices - ) - # spec_conn_epochs excludes freqs without at least 5 cycles, but not Spectrum - fstart = con.freqs.index(con_from_epochs.freqs[0]) - assert_allclose(con.get_data()[:, fstart:], con_from_epochs.get_data()) + atol = 0 + else: + atol = 7e-3 + # spec_conn_epochs excludes freqs without at least 5 cycles, but not Spectrum + fstart = con.freqs.index(con_from_epochs.freqs[0]) + assert_allclose( + np.abs(con.get_data()[:, fstart:]), + np.abs(con_from_epochs.get_data()), + atol=atol, + ) # Check connectivity values are as expected freqs = np.array(con.freqs) @@ -529,15 +547,16 @@ def test_spectral_connectivity_epochs_spectrum_input(mode): freqs > fband[1] + trans_bandwidth * 2 ) + # nothing for CaCoh to optimise, so use same thresholds for CaCoh and Coh if mode == "multitaper": # lower baseline for multitaper con_thresh = (0.1, 0.3) else: # higher baseline for Welch/Fourier con_thresh = (0.2, 0.4) # check freqs of simulated interaction show strong connectivity - assert_array_less(con_thresh[1], con.get_data()[:, freqs_con].mean()) + assert_array_less(con_thresh[1], np.abs(con.get_data()[:, freqs_con].mean())) # check freqs of no simulated interaction (just noise) show weak connectivity - assert_array_less(con.get_data()[:, freqs_noise].mean(), con_thresh[0]) + assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), con_thresh[0]) # TODO: Add general test for error catching for spec_conn_epochs From a0ebe2a468c8da53f7d86ff28868ca78a29f6e0f Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 7 Aug 2024 17:45:14 +0200 Subject: [PATCH 16/18] Update example from review --- examples/surrogate_connectivity.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py index ca759faa7..faab10f28 100644 --- a/examples/surrogate_connectivity.py +++ b/examples/surrogate_connectivity.py @@ -118,8 +118,8 @@ # Compute connectivity for surrogate data surrogate_con = [] -for shuffle_i, surrogate in enumerate(pretrial_surrogates): - print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") +for shuffle_i, surrogate in enumerate(pretrial_surrogates, 1): + print(f"Computing connectivity for shuffle {shuffle_i} of {n_shuffles}") surrogate_con.append( spectral_connectivity_epochs( surrogate, method="imcoh", indices=indices, verbose=False @@ -208,8 +208,8 @@ # # Any connectivity estimates will therefore overestimate the baseline connectivity in # your data, increasing the likelihood of type II errors (see the Notes section of -# :func:`~mne_connectivity.make_surrogate_data` for more information, and see the final -# section of this example for a demonstration). +# :func:`~mne_connectivity.make_surrogate_data` for more information, and see the +# section :ref:`inappropriate-surrogate-data` for a demonstration). # # **In cases where you want to assess connectivity in evoked data, you can use # surrogates generated from non-evoked data (of the same subject).** Here we do just @@ -262,9 +262,10 @@ print(f"P < {alpha}") if p_val < alpha else print(f"P > {alpha}") ######################################################################################## +# .. _inappropriate-surrogate-data: +# # Generating surrogate connectivity from inappropriate data # --------------------------------------------------------- -# # We discussed above how surrogates generated from evoked data risk overestimating the # degree of baseline connectivity. We demonstrate this below by generating surrogates # from the post-stimulus data. @@ -278,8 +279,8 @@ # Compute connectivity for evoked surrogate data bad_surrogate_con = [] -for shuffle_i, surrogate in enumerate(poststim_surrogates): - print(f"Computing connectivity for shuffle {shuffle_i+1} of {n_shuffles}") +for shuffle_i, surrogate in enumerate(poststim_surrogates, 1): + print(f"Computing connectivity for shuffle {shuffle_i} of {n_shuffles}") bad_surrogate_con.append( spectral_connectivity_epochs( surrogate, method="imcoh", indices=indices, verbose=False From ca6013de1353c1777cee244c51124c444442a106 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 19 Aug 2024 20:28:00 +0200 Subject: [PATCH 17/18] Update surrogate example --- examples/surrogate_connectivity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py index faab10f28..c83f89e9a 100644 --- a/examples/surrogate_connectivity.py +++ b/examples/surrogate_connectivity.py @@ -343,7 +343,7 @@ # modules have many such tools for testing this, e.g., :func:`scipy.stats.ttest_1samp`, # :func:`mne.stats.permutation_t_test`, etc... # -# Therefore, surrogate connectivity estimates are a powerful tool for assessing the +# Altogether, surrogate connectivity estimates are a powerful tool for assessing the # significance of connectivity estimates, both on a single recording- and group-level. ######################################################################################## From 64a322427d710641059816ecccb3a07f038c02cc Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 12 Nov 2024 18:25:23 +0000 Subject: [PATCH 18/18] Update example from review --- examples/surrogate_connectivity.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/surrogate_connectivity.py b/examples/surrogate_connectivity.py index c83f89e9a..763fe1982 100644 --- a/examples/surrogate_connectivity.py +++ b/examples/surrogate_connectivity.py @@ -170,13 +170,15 @@ # scores centred around non-zero values (e.g., 0.5 for directed phase lag index). # # Below, we determine the statistical significance of connectivity in the lower beta -# band using an alpha of 0.05. Naturally, any tests involving multiple connections, -# frequencies, and/or times should be corrected for multiple comparisons. Here however, -# we average over all connections and frequencies. +# band. We simplify this by averaging over all connections and corresponding frequency +# bins. We could of course also test the significance of each connection, each frequency +# bin, or other frequency bands such as the alpha band. Naturally, any tests involving +# multiple connections, frequencies, and/or times should be corrected for multiple +# comparisons. # # The test confirms our visual inspection, showing that connectivity in the lower beta -# band is significantly above the baseline level of connectivity, which we can take as -# evidence of genuine interactions in this frequency band. +# band is significantly above the baseline level of connectivity at an alpha of 0.05, +# which we can take as evidence of genuine interactions in this frequency band. # %% @@ -192,9 +194,8 @@ ).mean(axis=(1, 2)) # Compute p-value for pre-trial lower beta coupling -alpha = 0.05 p_val = np.sum(beta_con_pretrial <= beta_con_surrogate) / n_shuffles -print(f"P < {alpha}") if p_val < alpha else print(f"P > {alpha}") +print(f"P = {p_val:.2f}") ######################################################################################## # Assessing connectivity in evoked data @@ -259,7 +260,7 @@ # Compute p-value for post-stimulus lower beta coupling p_val = np.sum(beta_con_poststim <= beta_con_surrogate) / n_shuffles -print(f"P < {alpha}") if p_val < alpha else print(f"P > {alpha}") +print(f"P = {p_val:.2f}") ######################################################################################## # .. _inappropriate-surrogate-data: