From ec586735df6338a93dd6aae6e8de56c262f4bf20 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Fri, 28 Oct 2022 18:29:13 +0300 Subject: [PATCH 01/30] Add ciPLV Add the corrected imaginary Phase-Locking-Value into the list of available connectivity metrics. --- .../spectral/tests/test_spectral.py | 4 +- mne_connectivity/spectral/time.py | 38 ++++++++++++++++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 8b4c71a8..b68a50bb 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -472,7 +472,7 @@ def test_epochs_tmin_tmax(kind): assert len(w) == 1 # just one even though there were multiple epochs -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize( 'mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('data_option', ['sync', 'random']) @@ -526,7 +526,7 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): assert np.all(con_matrix) <= 0.5 -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize( 'cwt_freqs', [[8., 10.], [8, 10], 10., 10]) def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 1c16052a..92b6442e 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -183,6 +183,13 @@ def spectral_connectivity_time(data, method='coh', average=False, PLV = |E[Sxy/|Sxy|]| + 'ciplv' : corrected imaginary PLV (icPLV) + :footcite:`BrunaEtAl2018` given by:: + + |E[Im(Sxy/|Sxy|)]| + ciPLV = ------------------------------------ + sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) + 'sxy' : Cross spectrum Sxy 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: @@ -440,7 +447,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx, # compute for each connectivity method this_conn = {} conn_func = {'coh': _coh, 'plv': _plv, 'sxy': _cs, 'pli': _pli, - 'wpli': _wpli} + 'wpli': _wpli, 'ciplv': _ciplv} for m in method: c_func = conn_func[m] this_conn[m] = c_func(out, kernel, foi_idx, source_idx, @@ -523,6 +530,35 @@ def pairwise_plv(w_x, w_y): return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _ciplv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage): + """Pairwise corrected imaginary phase-locking value. + + Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, + n_times).""" + + def pairwise_ciplv(w_x, w_y): + s_xy = w[:, w_y] * np.conj(w[:, w_x]) + exp_dphi = s_xy / np.abs(s_xy) + exp_dphi = _smooth_spectra(exp_dphi, kernel) + + rplv = np.abs(np.mean(np.real(exp_dphi), axis=-1, keepdims=True)) + iplv = np.abs(np.mean(np.imag(exp_dphi), axis=-1, keepdims=True)) + + out = iplv / (np.sqrt(1 - rplv ** 2)) + # mean inside frequency sliding window (if needed) + if isinstance(foi_idx, np.ndarray) and faverage: + return _foi_average(out, foi_idx) + else: + return out + + # define the function to compute in parallel + parallel, p_fun, n_jobs = parallel_func( + pairwise_ciplv, n_jobs=n_jobs, verbose=verbose, total=total) + + return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) + + def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, faverage, weights): """Pairwise phase-lag index. From 31306721a764ae45f1d912e1e17d874a84f09579 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 31 Oct 2022 14:52:04 +0200 Subject: [PATCH 02/30] Speed up computation All connectivity measures are now computed with only a single computation of pairwise cross spectrum. --- mne_connectivity/spectral/time.py | 231 +++++++++--------------------- 1 file changed, 68 insertions(+), 163 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 92b6442e..39cbf636 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -367,11 +367,9 @@ def spectral_connectivity_time(data, method='coh', average=False, verbose=verbose) for epoch_idx in np.arange(n_epochs): - epoch_idx = [epoch_idx] - conn_tr = _spectral_connectivity(data[epoch_idx, ...], **call_params) + conn_tr = _spectral_connectivity(data[epoch_idx], **call_params) for m in method: - conn[m][epoch_idx, ...] = np.stack(conn_tr[m], - axis=1).squeeze(axis=-1) + conn[m][epoch_idx] = np.stack(conn_tr[m], axis=0) if indices is None: conn_flat = conn @@ -380,7 +378,7 @@ def spectral_connectivity_time(data, method='coh', average=False, this_conn = np.zeros((n_epochs, n_signals, n_signals) + conn_flat[m].shape[2:], dtype=conn_flat[m].dtype) - this_conn[:, source_idx, target_idx] = conn_flat[m][:, ...] + this_conn[:, source_idx, target_idx] = conn_flat[m] this_conn = this_conn.reshape((n_epochs, n_signals ** 2,) + conn_flat[m].shape[2:]) conn[m] = this_conn @@ -416,7 +414,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx, See spectral_connectivity_epochs.""" n_pairs = len(source_idx) - + data = np.expand_dims(data, axis=0) if mode == 'cwt_morlet': out = tfr_array_morlet( data, sfreq, freqs, n_cycles=n_cycles, output='complex', @@ -444,16 +442,14 @@ def _spectral_connectivity(data, method, kernel, foi_idx, else: raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") + out = np.squeeze(out, axis=0) + # compute for each connectivity method this_conn = {} - conn_func = {'coh': _coh, 'plv': _plv, 'sxy': _cs, 'pli': _pli, - 'wpli': _wpli, 'ciplv': _ciplv} - for m in method: - c_func = conn_func[m] - this_conn[m] = c_func(out, kernel, foi_idx, source_idx, - target_idx, n_jobs=n_jobs, - verbose=verbose, total=n_pairs, - faverage=faverage, weights=weights) + conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx, + n_jobs, verbose, n_pairs, faverage) + for i, m in enumerate(method): + this_conn[m] = [out[i] for out in conn] return this_conn @@ -464,164 +460,73 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ############################################################################### ############################################################################### -def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise coherence. +def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, + verbose, total, faverage): + """Compute spectral connectivity in parallel. - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" + Input signal w is of shape (n_chans, n_tapers, n_freqs, n_times).""" - if weights is not None: - psd = weights * w - psd = psd * np.conj(psd) - psd = psd.real.sum(axis=2) - psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0) - else: - psd = w.real ** 2 + w.imag ** 2 - psd = np.squeeze(psd, axis=2) + if 'coh' in method: + # auto spectra (faster than w * w.conj()) + s_auto = w.real ** 2 + w.imag ** 2 - # smooth the psd - psd = _smooth_spectra(psd, kernel) + # smooth the auto spectra + s_auto = _smooth_spectra(s_auto, kernel) - def pairwise_coh(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) + def pairwise_con(w_x, w_y): + s_xy = w[w_y] * np.conj(w[w_x]) + dphi = s_xy / np.abs(s_xy) + dphi = _smooth_spectra(dphi, kernel) s_xy = _smooth_spectra(s_xy, kernel) - s_xx = psd[:, w_x] - s_yy = psd[:, w_y] - out = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ - np.sqrt(s_xx.mean(axis=-1, keepdims=True) * - s_yy.mean(axis=-1, keepdims=True)) - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out - - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_coh, n_jobs=n_jobs, verbose=verbose, total=total) - - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) - - -def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise phase-locking value. - - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - def pairwise_plv(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - exp_dphi = s_xy / np.abs(s_xy) - exp_dphi = _smooth_spectra(exp_dphi, kernel) - # mean over time - exp_dphi_mean = exp_dphi.mean(axis=-1, keepdims=True) - out = np.abs(exp_dphi_mean) - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out - - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_plv, n_jobs=n_jobs, verbose=verbose, total=total) - - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) - - -def _ciplv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage): - """Pairwise corrected imaginary phase-locking value. - - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - - def pairwise_ciplv(w_x, w_y): - s_xy = w[:, w_y] * np.conj(w[:, w_x]) - exp_dphi = s_xy / np.abs(s_xy) - exp_dphi = _smooth_spectra(exp_dphi, kernel) - - rplv = np.abs(np.mean(np.real(exp_dphi), axis=-1, keepdims=True)) - iplv = np.abs(np.mean(np.imag(exp_dphi), axis=-1, keepdims=True)) - - out = iplv / (np.sqrt(1 - rplv ** 2)) - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out - - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_ciplv, n_jobs=n_jobs, verbose=verbose, total=total) - - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) - - -def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise phase-lag index. - - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - def pairwise_pli(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - s_xy = _smooth_spectra(s_xy, kernel) - out = np.abs(np.mean(np.sign(np.imag(s_xy)), - axis=-1, keepdims=True)) - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out - - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_pli, n_jobs=n_jobs, verbose=verbose, total=total) - - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) - - -def _wpli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise weighted phase-lag index. - - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - def pairwise_wpli(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - s_xy = _smooth_spectra(s_xy, kernel) - con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) - con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) - out = con_num / con_den - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out - - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_wpli, n_jobs=n_jobs, verbose=verbose, total=total) - - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) - + out = [] + for m in method: + if m == 'coh': + s_xx = s_auto[w_x] + s_yy = s_auto[w_y] + coh = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ + np.sqrt(s_xx.mean(axis=-1, keepdims=True) * + s_yy.mean(axis=-1, keepdims=True)) + out.append(coh) + + if m == 'plv': + dphi_mean = dphi.mean(axis=-1, keepdims=True) + plv = np.abs(dphi_mean) + out.append(plv) + + if m == 'ciplv': + rplv = np.abs(np.mean(np.real(dphi), axis=-1, keepdims=True)) + iplv = np.abs(np.mean(np.imag(dphi), axis=-1, keepdims=True)) + ciplv = iplv / (np.sqrt(1 - rplv ** 2)) + out.append(ciplv) + + if m == 'pli': + pli = np.abs(np.mean(np.sign(np.imag(s_xy)), + axis=-1, keepdims=True)) + out.append(pli) + + if m == 'wpli': + con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) + con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) + wpli = con_num / con_den + out.append(wpli) + + if m == 'cs': + out.append(s_xy) + + for i, _ in enumerate(out): + # mean over tapers + out[i] = np.mean(out[i], axis=0) + # mean inside frequency sliding window (if needed) + if isinstance(foi_idx, np.ndarray) and faverage: + out[i] = _foi_average(out[i], foi_idx) + # squeeze time dimension + out[i] = out[i].squeeze(axis=-1) -def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise cross-spectra.""" - def pairwise_cs(w_x, w_y): - out = _compute_csd(w[:, w_y], w[:, w_x], weights) - out = _smooth_spectra(out, kernel) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out + return out # define the function to compute in parallel parallel, p_fun, n_jobs = parallel_func( - pairwise_cs, n_jobs=n_jobs, verbose=verbose, total=total) + pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total) return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) From 996cf460421be0423579462ab342402e23935ba2 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 31 Oct 2022 15:00:54 +0200 Subject: [PATCH 03/30] Add logging --- mne_connectivity/spectral/time.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 39cbf636..41021c76 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -367,6 +367,7 @@ def spectral_connectivity_time(data, method='coh', average=False, verbose=verbose) for epoch_idx in np.arange(n_epochs): + logger.info(f' Processing epoch {epoch_idx+1} / {n_epochs} ...') conn_tr = _spectral_connectivity(data[epoch_idx], **call_params) for m in method: conn[m][epoch_idx] = np.stack(conn_tr[m], axis=0) From 6e6fa8846f3cb0b37a679607f17a1126a46af400 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Tue, 1 Nov 2022 17:09:18 +0200 Subject: [PATCH 04/30] Add the option to specify freqs in all modes In some scenarios, users might want to specify the frequencies for time-frequency decomposition also when using multitapering. These changes allow users to specify the 'freqs' parameter to override the automatically determined frequencies. --- .../spectral/tests/test_spectral.py | 8 +-- mne_connectivity/spectral/time.py | 54 +++++++++++++------ 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index b68a50bb..2fc9c60d 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -504,11 +504,11 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): # hypothesized "connection" freq_band_low_limit = (8.) freq_band_high_limit = (13.) - cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) + freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) if mode == 'cwt_morlet' else None con = spectral_connectivity_time(data, method=method, mode=mode, sfreq=sfreq, fmin=freq_band_low_limit, fmax=freq_band_high_limit, - cwt_freqs=cwt_freqs, n_jobs=1, + freqs=freqs, n_jobs=1, faverage=True, average=True, sm_times=0) assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] @@ -589,11 +589,11 @@ def test_spectral_connectivity_time_resolved(method, mode): data = EpochsArray(data, info) # define some frequencies for cwt - freqs = np.arange(3, 20.5, 1) + freqs = np.arange(3, 20.5, 1) if mode == 'cwt_morlet' else None # run connectivity estimation con = spectral_connectivity_time( - data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode, + data, sfreq=sfreq, freqs=freqs, method=method, mode=mode, n_cycles=5) assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) assert con.get_data(output='dense').shape == \ diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 41021c76..c7832102 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -12,7 +12,7 @@ from mne.utils import (logger, warn) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) -from .epochs import _compute_freqs, _compute_freq_mask +from .epochs import _compute_freq_mask from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, fill_doc @@ -22,9 +22,8 @@ def spectral_connectivity_time(data, method='coh', average=False, indices=None, sfreq=None, fmin=None, fmax=None, fskip=0, faverage=False, sm_times=0, sm_freqs=1, sm_kernel='hanning', - mode='cwt_morlet', mt_bandwidth=None, - cwt_freqs=None, n_cycles=7, decim=1, - n_jobs=1, verbose=None): + mode='cwt_morlet', mt_bandwidth=None, freqs=None, + n_cycles=7, decim=1, n_jobs=1, verbose=None): """Compute frequency- and time-frequency-domain connectivity measures. This method computes time-resolved connectivity measures from epoched data. @@ -92,11 +91,12 @@ def spectral_connectivity_time(data, method='coh', average=False, bandwidth (thus the frequency resolution) and the number of good tapers. See :func:`mne.time_frequency.tfr_array_multitaper` documentation. - cwt_freqs : array_like + freqs : array_like Array of frequencies of interest for time-frequency decomposition. - Only used in 'cwt_morlet' mode. Only the frequencies within - the range specified by ``fmin`` and ``fmax`` are used. Required if - ``mode='cwt_morlet'``. Not used when ``mode='multitaper'``. + Required in ``cwt_morlet`` mode. Only the frequencies within + the range specified by ``fmin`` and ``fmax`` are used. Required if + ``mode='cwt_morlet'``. If set when ``mode='multitaper'``, overrides the + automatically determined frequencies of interest. n_cycles : float | array_like of float Number of cycles in the wavelet, either a fixed number or one per frequency. The number of cycles ``n_cycles`` and the frequencies of @@ -314,22 +314,22 @@ def spectral_connectivity_time(data, method='coh', average=False, target_idx = indices_use[1] n_pairs = len(source_idx) - # check cwt_freqs - if cwt_freqs is not None: + # check freqs + if freqs is not None: # check for single frequency - if isinstance(cwt_freqs, (int, float)): - cwt_freqs = [cwt_freqs] + if isinstance(freqs, (int, float)): + freqs = [freqs] # array conversion - cwt_freqs = np.asarray(cwt_freqs) + freqs = np.asarray(freqs) # check order for multiple frequencies - if len(cwt_freqs) >= 2: - delta_f = np.diff(cwt_freqs) + if len(freqs) >= 2: + delta_f = np.diff(freqs) increase = np.all(delta_f > 0) assert increase, "Frequencies should be in increasing order" # compute frequencies to analyze based on number of samples, # sampling rate, specified wavelet frequencies and mode - freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode) + freqs = _compute_freqs(n_times, sfreq, freqs, mode) # compute the mask based on specified min/max and decimation factor freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip) @@ -573,3 +573,25 @@ def _foi_average(conn, foi_idx): f_e += 1 if f_s == f_e else f_e conn_f[..., n_f, :] = conn[..., f_s:f_e, :].mean(-2) return conn_f + + +def _compute_freqs(n_times, sfreq, freqs, mode): + from scipy.fft import rfftfreq + # get frequencies of interest for the different modes + if freqs is not None: + if any(freqs > (sfreq / 2.)): + raise ValueError('entries in freqs cannot be ' + 'larger than Nyquist (sfreq / 2)') + else: + return freqs.astype(np.float64) + if mode in ('multitaper', 'fourier'): + # fmin fmax etc is only supported for these modes + # decide which frequencies to keep + return rfftfreq(n_times, 1. / sfreq) + elif mode == 'cwt_morlet': + # cwt_morlet mode + if freqs is None: + raise ValueError('define frequencies of interest using ' + 'cwt_freqs') + else: + raise ValueError('mode has an invalid value') From 26561be56f863543a072eb0b5d2bb15b6406dde8 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 10 Nov 2022 12:26:29 +0200 Subject: [PATCH 05/30] BUG: Average over CSD instead of connectivity --- mne_connectivity/spectral/time.py | 37 +++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index c7832102..ae244d6f 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -413,7 +413,8 @@ def _spectral_connectivity(data, method, kernel, foi_idx, n_jobs, verbose): """Estimate time-resolved connectivity for one epoch. - See spectral_connectivity_epochs.""" + Data is of shape (n_channels, n_times).""" + n_pairs = len(source_idx) data = np.expand_dims(data, axis=0) if mode == 'cwt_morlet': @@ -448,7 +449,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx, # compute for each connectivity method this_conn = {} conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx, - n_jobs, verbose, n_pairs, faverage) + n_jobs, verbose, n_pairs, faverage, weights) for i, m in enumerate(method): this_conn[m] = [out[i] for out in conn] @@ -462,28 +463,42 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ############################################################################### def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, - verbose, total, faverage): + verbose, total, faverage, weights): """Compute spectral connectivity in parallel. Input signal w is of shape (n_chans, n_tapers, n_freqs, n_times).""" if 'coh' in method: - # auto spectra (faster than w * w.conj()) - s_auto = w.real ** 2 + w.imag ** 2 + # psd + if weights is not None: + psd = weights * w + psd = psd * np.conj(psd) + psd = psd.real.sum(axis=1) + psd = psd * 2 / (weights*weights.conj()).real.sum(axis=0) + else: + psd = w.real ** 2 + w.imag ** 2 + psd = np.squeeze(psd, axis=1) - # smooth the auto spectra - s_auto = _smooth_spectra(s_auto, kernel) + # smooth + psd = _smooth_spectra(psd, kernel) def pairwise_con(w_x, w_y): - s_xy = w[w_y] * np.conj(w[w_x]) + # csd + if weights is not None: + s_xy = np.sum(weights * w[w_x] * np.conj(weights * w[w_y]), axis=0) + s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=0) + else: + s_xy = w[w_x] * np.conj(w[w_y]) + s_xy = np.squeeze(s_xy, axis=0) + dphi = s_xy / np.abs(s_xy) dphi = _smooth_spectra(dphi, kernel) s_xy = _smooth_spectra(s_xy, kernel) out = [] for m in method: if m == 'coh': - s_xx = s_auto[w_x] - s_yy = s_auto[w_y] + s_xx = psd[w_x] + s_yy = psd[w_y] coh = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ np.sqrt(s_xx.mean(axis=-1, keepdims=True) * s_yy.mean(axis=-1, keepdims=True)) @@ -515,8 +530,6 @@ def pairwise_con(w_x, w_y): out.append(s_xy) for i, _ in enumerate(out): - # mean over tapers - out[i] = np.mean(out[i], axis=0) # mean inside frequency sliding window (if needed) if isinstance(foi_idx, np.ndarray) and faverage: out[i] = _foi_average(out[i], foi_idx) From a50053f26d287af708844fc1f8358126337fdb99 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 14 Nov 2022 15:41:03 +0200 Subject: [PATCH 06/30] Add option to use part of signal as padding This adds the option to use the edges of the signal at each epoch as padding. The purpose of this is to avoid edge effects generated by the time-frequency transformation methods. --- mne_connectivity/spectral/time.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index ae244d6f..9bed6dc5 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -21,7 +21,7 @@ def spectral_connectivity_time(data, method='coh', average=False, indices=None, sfreq=None, fmin=None, fmax=None, fskip=0, faverage=False, sm_times=0, - sm_freqs=1, sm_kernel='hanning', + sm_freqs=1, sm_kernel='hanning', padding=0, mode='cwt_morlet', mt_bandwidth=None, freqs=None, n_cycles=7, decim=1, n_jobs=1, verbose=None): """Compute frequency- and time-frequency-domain connectivity measures. @@ -80,6 +80,9 @@ def spectral_connectivity_time(data, method='coh', average=False, is equivalent to no smoothing. sm_kernel : {'square', 'hanning'} Smoothing kernel type. Choose either 'square' or 'hanning'. + padding: float + Amount of time to consider as padding at the beginning and end of each + epoch in seconds. mode : str Time-frequency decomposition method. Can be either: 'multitaper', or 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and @@ -363,7 +366,7 @@ def spectral_connectivity_time(data, method='coh', average=False, source_idx=source_idx, target_idx=target_idx, mode=mode, sfreq=sfreq, freqs=freqs, faverage=faverage, n_cycles=n_cycles, mt_bandwidth=mt_bandwidth, - decim=decim, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, + decim=decim, padding=padding, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, verbose=verbose) for epoch_idx in np.arange(n_epochs): @@ -409,7 +412,7 @@ def spectral_connectivity_time(data, method='coh', average=False, def _spectral_connectivity(data, method, kernel, foi_idx, source_idx, target_idx, mode, sfreq, freqs, faverage, n_cycles, - mt_bandwidth, decim, kw_cwt, kw_mt, + mt_bandwidth, decim, padding, kw_cwt, kw_mt, n_jobs, verbose): """Estimate time-resolved connectivity for one epoch. @@ -446,6 +449,11 @@ def _spectral_connectivity(data, method, kernel, foi_idx, out = np.squeeze(out, axis=0) + if padding: + pad_idx = int(np.floor(padding * sfreq / decim)) + out = out[..., pad_idx:-pad_idx] + weights = weights[..., pad_idx:-pad_idx] + # compute for each connectivity method this_conn = {} conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx, From 5c778c1033a88ee777aa5b0616c09a451d760960 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 17 Nov 2022 16:12:35 +0200 Subject: [PATCH 07/30] Fix test bug, use 'freqs' instead of 'cwt_freqs' --- mne_connectivity/spectral/tests/test_spectral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 2fc9c60d..df4f3760 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -555,7 +555,7 @@ def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): con = spectral_connectivity_time(data, method=method, mode='cwt_morlet', sfreq=sfreq, fmin=np.min(cwt_freqs), fmax=np.max(cwt_freqs), - cwt_freqs=cwt_freqs, n_jobs=1, + freqs=cwt_freqs, n_jobs=1, faverage=True, average=True, sm_times=0) assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] From 8f37ae67fa5dc69f9688136d6d3dfbe1f777cd71 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 24 Nov 2022 13:48:33 +0200 Subject: [PATCH 08/30] Fix bug with dpss windows Sym is not a parameter of dpss_windows. (But is one of the underlying scipy.signal.dpss) --- mne_connectivity/spectral/time.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 9bed6dc5..303e6f15 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -440,8 +440,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx, window_length = np.arange(0., n_c / float(f), 1.0 / sfreq).shape[0] half_nbw = mt_bandwidth / 2. n_tapers = int(np.floor(mt_bandwidth - 1)) - _, eigvals = dpss_windows(window_length, half_nbw, n_tapers, - sym=False) + _, eigvals = dpss_windows(window_length, half_nbw, n_tapers) weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis]) # weights have shape (n_tapers, n_freqs, n_times) else: From 4496bce0d11b066ef338ba2cfdcfbb025269843d Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 24 Nov 2022 13:55:03 +0200 Subject: [PATCH 09/30] Only show progress bar if verbosity level is DEBUG This change will skip the rendering of the connectivity computation progress bar if the logging level is not DEBUG. This is in line with MNE-Python, where progress bars are not shown at INFO or higher logging levels. Rendering the progress bar regardless of logging levels has the potential to cause unnecessary clutter in users' log files. --- mne_connectivity/spectral/time.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 303e6f15..11a32063 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -545,6 +545,10 @@ def pairwise_con(w_x, w_y): return out + # only show progress if verbosity level is DEBUG + if verbose != 'DEBUG' and verbose != 'debug' and verbose != 10: + total = None + # define the function to compute in parallel parallel, p_fun, n_jobs = parallel_func( pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total) From a7680382074a75ae1b712b8752d0c88740a499bc Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 24 Nov 2022 14:42:05 +0200 Subject: [PATCH 10/30] Improve doc Add a better description of the method + style nitpicks. --- mne_connectivity/spectral/time.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 11a32063..23138596 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -24,13 +24,14 @@ def spectral_connectivity_time(data, method='coh', average=False, sm_freqs=1, sm_kernel='hanning', padding=0, mode='cwt_morlet', mt_bandwidth=None, freqs=None, n_cycles=7, decim=1, n_jobs=1, verbose=None): - """Compute frequency- and time-frequency-domain connectivity measures. + """Compute time-frequency-domain connectivity measures. - This method computes time-resolved connectivity measures from epoched data. + This function computes spectral connectivity over time from epoched data. + The data may consist of a single epoch. The connectivity method(s) are specified using the ``method`` parameter. - All methods are based on estimates of the cross- and power spectral - densities (CSD/PSD) Sxy and Sxx, Syy. + All methods are based on time-resolved estimates of the cross- and + power spectral densities (CSD/PSD) Sxy and Sxx, Syy. Parameters ---------- @@ -42,11 +43,12 @@ def spectral_connectivity_time(data, method='coh', average=False, * 'coh' : Coherence * 'plv' : Phase-Locking Value (PLV) + * 'ciplv' : Corrected imaginary Phase-Locking Value * 'sxy' : Cross-spectrum * 'pli' : Phase-Lag Index * 'wpli': Weighted Phase-Lag Index average : bool - Average connectivity scores over epochs. If True, output will be + Average connectivity scores over epochs. If ``True``, output will be an instance of :class:`SpectralConnectivity`, otherwise :class:`EpochSpectralConnectivity`. indices : tuple of array_like | None @@ -186,7 +188,7 @@ def spectral_connectivity_time(data, method='coh', average=False, PLV = |E[Sxy/|Sxy|]| - 'ciplv' : corrected imaginary PLV (icPLV) + 'ciplv' : Corrected imaginary PLV (icPLV) :footcite:`BrunaEtAl2018` given by:: |E[Im(Sxy/|Sxy|)]| From c95b0cd4e564cdad7fae99075222957973fe85bf Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 24 Nov 2022 14:46:09 +0200 Subject: [PATCH 11/30] Fix style to make flake happy --- mne_connectivity/spectral/time.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 23138596..a8debf91 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -22,8 +22,9 @@ def spectral_connectivity_time(data, method='coh', average=False, indices=None, sfreq=None, fmin=None, fmax=None, fskip=0, faverage=False, sm_times=0, sm_freqs=1, sm_kernel='hanning', padding=0, - mode='cwt_morlet', mt_bandwidth=None, freqs=None, - n_cycles=7, decim=1, n_jobs=1, verbose=None): + mode='cwt_morlet', mt_bandwidth=None, + freqs=None, n_cycles=7, decim=1, n_jobs=1, + verbose=None): """Compute time-frequency-domain connectivity measures. This function computes spectral connectivity over time from epoched data. @@ -483,7 +484,7 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, psd = weights * w psd = psd * np.conj(psd) psd = psd.real.sum(axis=1) - psd = psd * 2 / (weights*weights.conj()).real.sum(axis=0) + psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0) else: psd = w.real ** 2 + w.imag ** 2 psd = np.squeeze(psd, axis=1) @@ -509,8 +510,8 @@ def pairwise_con(w_x, w_y): s_xx = psd[w_x] s_yy = psd[w_y] coh = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ - np.sqrt(s_xx.mean(axis=-1, keepdims=True) * - s_yy.mean(axis=-1, keepdims=True)) + np.sqrt(s_xx.mean(axis=-1, keepdims=True) * + s_yy.mean(axis=-1, keepdims=True)) out.append(coh) if m == 'plv': From c5102045dc49d6f9468663ddb0c70ed43c509ee1 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 24 Nov 2022 15:07:08 +0200 Subject: [PATCH 12/30] Fix whitespace --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index a8debf91..0c02b538 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -83,7 +83,7 @@ def spectral_connectivity_time(data, method='coh', average=False, is equivalent to no smoothing. sm_kernel : {'square', 'hanning'} Smoothing kernel type. Choose either 'square' or 'hanning'. - padding: float + padding : float Amount of time to consider as padding at the beginning and end of each epoch in seconds. mode : str From 6ca2aa691fd21603e0ba06a51bc950b77d59153a Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Fri, 25 Nov 2022 10:57:48 +0200 Subject: [PATCH 13/30] Fix style --- mne_connectivity/spectral/tests/test_spectral.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index df4f3760..5cbb809f 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -504,7 +504,8 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): # hypothesized "connection" freq_band_low_limit = (8.) freq_band_high_limit = (13.) - freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) if mode == 'cwt_morlet' else None + freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) \ + if mode == 'cwt_morlet' else None con = spectral_connectivity_time(data, method=method, mode=mode, sfreq=sfreq, fmin=freq_band_low_limit, fmax=freq_band_high_limit, From ba3bad2f6932573cf39aa7a4eec8571018f8bf31 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 31 Oct 2022 14:52:04 +0200 Subject: [PATCH 14/30] Refactor connectivity methods Individual functions for each connectivity methods. --- mne_connectivity/spectral/time.py | 137 +++++++++++++++++------------- 1 file changed, 78 insertions(+), 59 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 8f430b3a..c0f41744 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -492,72 +492,91 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, # smooth psd = _smooth_spectra(psd, kernel) - - def pairwise_con(w_x, w_y): - # csd - if weights is not None: - s_xy = np.sum(weights * w[w_x] * np.conj(weights * w[w_y]), axis=0) - s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=0) - else: - s_xy = w[w_x] * np.conj(w[w_y]) - s_xy = np.squeeze(s_xy, axis=0) - - dphi = s_xy / np.abs(s_xy) - dphi = _smooth_spectra(dphi, kernel) - s_xy = _smooth_spectra(s_xy, kernel) - out = [] - for m in method: - if m == 'coh': - s_xx = psd[w_x] - s_yy = psd[w_y] - coh = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ - np.sqrt(s_xx.mean(axis=-1, keepdims=True) * - s_yy.mean(axis=-1, keepdims=True)) - out.append(coh) - - if m == 'plv': - dphi_mean = dphi.mean(axis=-1, keepdims=True) - plv = np.abs(dphi_mean) - out.append(plv) - - if m == 'ciplv': - rplv = np.abs(np.mean(np.real(dphi), axis=-1, keepdims=True)) - iplv = np.abs(np.mean(np.imag(dphi), axis=-1, keepdims=True)) - ciplv = iplv / (np.sqrt(1 - rplv ** 2)) - out.append(ciplv) - - if m == 'pli': - pli = np.abs(np.mean(np.sign(np.imag(s_xy)), - axis=-1, keepdims=True)) - out.append(pli) - - if m == 'wpli': - con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) - con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) - wpli = con_num / con_den - out.append(wpli) - - if m == 'cs': - out.append(s_xy) - - for i, _ in enumerate(out): - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - out[i] = _foi_average(out[i], foi_idx) - # squeeze time dimension - out[i] = out[i].squeeze(axis=-1) - - return out + else: + psd = None # only show progress if verbosity level is DEBUG if verbose != 'DEBUG' and verbose != 'debug' and verbose != 10: total = None # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total) + parallel, my_pairwise_con, n_jobs = parallel_func( + _pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total) + + return parallel( + my_pairwise_con(w[s], w[t], s, t, psd, method, kernel, + foi_idx, faverage, weights) + for s, t in zip(source_idx, target_idx)) + + +def _pairwise_con(w_x, w_y, x, y, psd, method, kernel, foi_idx, + faverage, weights): + # csd + if weights is not None: + s_xy = np.sum(weights * w_x * np.conj(weights * w_y), axis=0) + s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=0) + else: + s_xy = w_x * np.conj(w_y) + s_xy = np.squeeze(s_xy, axis=0) + s_xy = _smooth_spectra(s_xy, kernel) + out = [] + conn_func = {'plv': _plv, 'ciplv': _ciplv, 'pli': _pli, 'wpli': _wpli, + 'coh': _coh, 'cs': _cs} + for m in method: + if m == 'coh': + s_xx = psd[x] + s_yy = psd[y] + out.append(conn_func[m](s_xx, s_yy, s_xy)) + else: + out.append(conn_func[m](s_xy)) + + for i, _ in enumerate(out): + # mean inside frequency sliding window (if needed) + if isinstance(foi_idx, np.ndarray) and faverage: + out[i] = _foi_average(out[i], foi_idx) + # squeeze time dimension + out[i] = out[i].squeeze(axis=-1) + + return out + + +def _plv(s_xy): + s_xy = s_xy / np.abs(s_xy) + plv = np.abs(s_xy.mean(axis=-1, keepdims=True)) + return plv + + +def _ciplv(s_xy): + s_xy = s_xy / np.abs(s_xy) + rplv = np.abs(np.mean(np.real(s_xy), axis=-1, keepdims=True)) + iplv = np.abs(np.mean(np.imag(s_xy), axis=-1, keepdims=True)) + ciplv = iplv / (np.sqrt(1 - rplv ** 2)) + return ciplv + + +def _pli(s_xy): + pli = np.abs(np.mean(np.sign(np.imag(s_xy)), + axis=-1, keepdims=True)) + return pli + + +def _wpli(s_xy): + con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) + con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) + wpli = con_num / con_den + return wpli + + +def _coh(s_xx, s_yy, s_xy): + con_num = np.abs(s_xy.mean(axis=-1, keepdims=True)) + con_den = np.sqrt(s_xx.mean(axis=-1, keepdims=True) * + s_yy.mean(axis=-1, keepdims=True)) + coh = con_num / con_den + return coh + - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _cs(s_xy): + return s_xy.mean(axis=-1, keepdims=True) def _compute_csd(x, y, weights): From 71d1259b1c34cb607a73d540587deb67a109ec11 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 8 Dec 2022 16:15:26 +0200 Subject: [PATCH 15/30] Fix parallelization This change makes joblib happy to do multithreading. --- mne_connectivity/spectral/time.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index c0f41744..ce95485e 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -504,14 +504,14 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, _pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total) return parallel( - my_pairwise_con(w[s], w[t], s, t, psd, method, kernel, + my_pairwise_con(w, psd, s, t, method, kernel, foi_idx, faverage, weights) for s, t in zip(source_idx, target_idx)) -def _pairwise_con(w_x, w_y, x, y, psd, method, kernel, foi_idx, - faverage, weights): - # csd +def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, + faverage, weights): + w_x, w_y = w[x], w[y] if weights is not None: s_xy = np.sum(weights * w_x * np.conj(weights * w_y), axis=0) s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=0) From 7ea1af869b25111006817f7654d379955d19a8c6 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 8 Dec 2022 17:02:45 +0200 Subject: [PATCH 16/30] Add docstring --- mne_connectivity/spectral/time.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index ce95485e..9a2fac7c 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -511,6 +511,7 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, faverage, weights): + """Compute spectral connectivity metrics between two signals.""" w_x, w_y = w[x], w[y] if weights is not None: s_xy = np.sum(weights * w_x * np.conj(weights * w_y), axis=0) From 5215fbe677e7fc717bf3a5f6e509998470db0a09 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 5 Jan 2023 17:11:26 +0200 Subject: [PATCH 17/30] Require freqs in all tfr modes The user is required to specify the wavelet central frequencies in both multitaper and cwt_morlet tfr mode. The reasoning is that the underlying tfr implementations are very similar. This is in contrast to spectral_connectivity_epochs, where multitaper assumes that the spectrum is stationary and therefore no wavelets are used. --- .../spectral/tests/test_spectral.py | 28 ++--- mne_connectivity/spectral/time.py | 106 ++++++------------ 2 files changed, 51 insertions(+), 83 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 5cbb809f..2678b2a4 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -504,12 +504,11 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): # hypothesized "connection" freq_band_low_limit = (8.) freq_band_high_limit = (13.) - freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) \ - if mode == 'cwt_morlet' else None - con = spectral_connectivity_time(data, method=method, mode=mode, + freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) + con = spectral_connectivity_time(data, freqs, method=method, mode=mode, sfreq=sfreq, fmin=freq_band_low_limit, fmax=freq_band_high_limit, - freqs=freqs, n_jobs=1, + n_jobs=1, faverage=True, average=True, sm_times=0) assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] @@ -529,10 +528,11 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): @pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize( - 'cwt_freqs', [[8., 10.], [8, 10], 10., 10]) -def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): + 'freqs', [[8., 10.], [8, 10], 10., 10]) +@pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper']) +def test_spectral_connectivity_time_freqs(method, freqs, mode): """Test time-resolved spectral connectivity with int and float values for - cwt_freqs.""" + freqs.""" rng = np.random.default_rng(0) n_epochs = 5 n_channels = 3 @@ -553,10 +553,10 @@ def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): data[i, c] = np.squeeze(np.sin(x)) # the frequency band should contain the frequency at which there is a # hypothesized "connection" - con = spectral_connectivity_time(data, method=method, mode='cwt_morlet', - sfreq=sfreq, fmin=np.min(cwt_freqs), - fmax=np.max(cwt_freqs), - freqs=cwt_freqs, n_jobs=1, + con = spectral_connectivity_time(data, freqs, method=method, + mode=mode, sfreq=sfreq, + fmin=np.min(freqs), + fmax=np.max(freqs), n_jobs=1, faverage=True, average=True, sm_times=0) assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] @@ -589,12 +589,12 @@ def test_spectral_connectivity_time_resolved(method, mode): info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') data = EpochsArray(data, info) - # define some frequencies for cwt - freqs = np.arange(3, 20.5, 1) if mode == 'cwt_morlet' else None + # define some frequencies for tfr + freqs = np.arange(3, 20.5, 1) # run connectivity estimation con = spectral_connectivity_time( - data, sfreq=sfreq, freqs=freqs, method=method, mode=mode, + data, freqs, sfreq=sfreq, method=method, mode=mode, n_cycles=5) assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) assert con.get_data(output='dense').shape == \ diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 9a2fac7c..54c3a398 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -9,7 +9,7 @@ from mne.parallel import parallel_func from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper, dpss_windows) -from mne.utils import (logger, warn, verbose) +from mne.utils import (logger, verbose) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) from .epochs import _compute_freq_mask @@ -19,12 +19,12 @@ @verbose @fill_doc -def spectral_connectivity_time(data, method='coh', average=False, +def spectral_connectivity_time(data, freqs, method='coh', average=False, indices=None, sfreq=None, fmin=None, fmax=None, fskip=0, faverage=False, sm_times=0, sm_freqs=1, sm_kernel='hanning', padding=0, mode='cwt_morlet', mt_bandwidth=None, - freqs=None, n_cycles=7, decim=1, n_jobs=1, + n_cycles=7, decim=1, n_jobs=1, verbose=None): """Compute time-frequency-domain connectivity measures. @@ -39,10 +39,13 @@ def spectral_connectivity_time(data, method='coh', average=False, ---------- data : array_like, shape (n_epochs, n_signals, n_times) | Epochs The data from which to compute connectivity. + freqs : array_like + Array of frequencies of interest for time-frequency decomposition. + Only the frequencies within the range specified by ``fmin`` and + ``fmax`` are used. method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'plv', 'sxy', 'pli', 'wpli']``. These are: - * 'coh' : Coherence * 'plv' : Phase-Locking Value (PLV) * 'ciplv' : Corrected imaginary Phase-Locking Value @@ -63,12 +66,11 @@ def spectral_connectivity_time(data, method='coh', average=False, fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower - bounds. If `None`, the frequency corresponding to an epoch length of - 5 cycles is used. + bounds. If `None`, the lowest frequency in ``freqs`` is used. fmax : float | tuple of float | None The upper frequency of interest. Multiple bands are defined using a tuple, e.g. ``(13., 30.)`` for two band with 13 Hz and 30 Hz upper - bounds. If `None`, ``sfreq/2`` is used. + bounds. If `None`, the highest frequency in ``freqs`` is used. fskip : int Omit every ``(fskip + 1)``-th frequency bin to decimate in frequency domain. @@ -98,12 +100,6 @@ def spectral_connectivity_time(data, method='coh', average=False, bandwidth (thus the frequency resolution) and the number of good tapers. See :func:`mne.time_frequency.tfr_array_multitaper` documentation. - freqs : array_like - Array of frequencies of interest for time-frequency decomposition. - Required in ``cwt_morlet`` mode. Only the frequencies within - the range specified by ``fmin`` and ``fmax`` are used. Required if - ``mode='cwt_morlet'``. If set when ``mode='multitaper'``, overrides the - automatically determined frequencies of interest. n_cycles : float | array_like of float Number of cycles in the wavelet, either a fixed number or one per frequency. The number of cycles ``n_cycles`` and the frequencies of @@ -269,25 +265,13 @@ def spectral_connectivity_time(data, method='coh', average=False, if isinstance(method, str): method = [method] - # check that fmin corresponds to at least 5 cycles - dur = float(n_times) / sfreq - five_cycle_freq = 5. / dur + # defaults for fmin and fmax if fmin is None: - # use the 5 cycle freq. as default - fmin = five_cycle_freq - logger.info(f'Fmin was not specified. Using fmin={fmin:.2f}, which ' - 'corresponds to at least five cycles.') - else: - if np.any(fmin < five_cycle_freq): - warn('fmin=%0.3f Hz corresponds to %0.3f < 5 cycles ' - 'based on the epoch length %0.3f sec, need at least %0.3f ' - 'sec epochs or fmin=%0.3f. Spectrum estimate will be ' - 'unreliable.' % (np.min(fmin), dur * np.min(fmin), dur, - 5. / np.min(fmin), five_cycle_freq)) + fmin = np.min(freqs) + logger.info('Fmin was not specified. Using fmin=min(freqs)') if fmax is None: - fmax = sfreq / 2 - logger.info(f'Fmax was not specified. Using fmax={fmax:.2f}, which ' - f'corresponds to Nyquist.') + fmax = np.max(freqs) + logger.info('Fmax was not specified. Using fmax=max(freqs).') fmin = np.array((fmin,), dtype=float).ravel() fmax = np.array((fmax,), dtype=float).ravel() @@ -322,23 +306,29 @@ def spectral_connectivity_time(data, method='coh', average=False, n_pairs = len(source_idx) # check freqs - if freqs is not None: - # check for single frequency - if isinstance(freqs, (int, float)): - freqs = [freqs] - # array conversion - freqs = np.asarray(freqs) - # check order for multiple frequencies - if len(freqs) >= 2: - delta_f = np.diff(freqs) - increase = np.all(delta_f > 0) - assert increase, "Frequencies should be in increasing order" - - # compute frequencies to analyze based on number of samples, - # sampling rate, specified wavelet frequencies and mode - freqs = _compute_freqs(n_times, sfreq, freqs, mode) - - # compute the mask based on specified min/max and decimation factor + if isinstance(freqs, (int, float)): + freqs = [freqs] + # array conversion + freqs = np.asarray(freqs) + # check order for multiple frequencies + if len(freqs) >= 2: + delta_f = np.diff(freqs) + increase = np.all(delta_f > 0) + assert increase, "Frequencies should be in increasing order" + + # check that freqs corresponds to at least n_cycles cycles + dur = float(n_times) / sfreq + cycle_freq = n_cycles / dur + if np.any(freqs < cycle_freq): + raise ValueError('At least one value in n_cycles corresponds to a' + 'wavelet longer than the signal. Use less cycles, ' + 'higher frequencies, or longer epochs.') + # check for Nyquist + if np.any(freqs > sfreq / 2): + raise ValueError(f'Frequencies {freqs[freqs > sfreq / 2]} Hz are ' + f'larger than Nyquist = {sfreq / 2:.2f} Hz') + + # compute frequency mask based on specified min/max and decimation factor freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip) # the frequency points where we compute connectivity @@ -621,25 +611,3 @@ def _foi_average(conn, foi_idx): f_e += 1 if f_s == f_e else f_e conn_f[..., n_f, :] = conn[..., f_s:f_e, :].mean(-2) return conn_f - - -def _compute_freqs(n_times, sfreq, freqs, mode): - from scipy.fft import rfftfreq - # get frequencies of interest for the different modes - if freqs is not None: - if any(freqs > (sfreq / 2.)): - raise ValueError('entries in freqs cannot be ' - 'larger than Nyquist (sfreq / 2)') - else: - return freqs.astype(np.float64) - if mode in ('multitaper', 'fourier'): - # fmin fmax etc is only supported for these modes - # decide which frequencies to keep - return rfftfreq(n_times, 1. / sfreq) - elif mode == 'cwt_morlet': - # cwt_morlet mode - if freqs is None: - raise ValueError('define frequencies of interest using ' - 'cwt_freqs') - else: - raise ValueError('mode has an invalid value') From 8d0223f1299f1b12a056248f5026ff97e6ff7dcf Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Fri, 6 Jan 2023 16:08:29 +0200 Subject: [PATCH 18/30] Add error checks for padding --- mne_connectivity/spectral/time.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 54c3a398..a742580e 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -443,9 +443,15 @@ def _spectral_connectivity(data, method, kernel, foi_idx, out = np.squeeze(out, axis=0) if padding: + if padding < 0: + raise ValueError(f'Padding cannot be negative, got {padding}.') + if padding >= data.shape[-1] / sfreq / 2: + raise ValueError(f'Padding cannot be larger than half of data ' + f'length, got {padding}.') pad_idx = int(np.floor(padding * sfreq / decim)) out = out[..., pad_idx:-pad_idx] - weights = weights[..., pad_idx:-pad_idx] + weights = weights[..., pad_idx:-pad_idx] if weights is not None \ + else None # compute for each connectivity method this_conn = {} From ee08bba3389a678f531eb0dce7d54fb4be48f79c Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Fri, 6 Jan 2023 16:08:37 +0200 Subject: [PATCH 19/30] Add test for padding --- .../spectral/tests/test_spectral.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 2678b2a4..c24b3914 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -614,6 +614,62 @@ def test_spectral_connectivity_time_resolved(method, mode): for idx, jdx in triu_inds) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize( + 'mode', ['cwt_morlet', 'multitaper']) +@pytest.mark.parametrize('padding', [0, 1, 5]) +def test_spectral_connectivity_time_padding(method, mode, padding): + """Test time-resolved spectral connectivity.""" + sfreq = 50. + n_signals = 3 + n_epochs = 2 + n_times = 300 + trans_bandwidth = 2. + tmin = 0. + tmax = (n_times - 1) / sfreq + # 5Hz..15Hz + fstart, fend = 5.0, 15.0 + data, _ = create_test_dataset( + sfreq, n_signals=n_signals, n_epochs=n_epochs, n_times=n_times, + tmin=tmin, tmax=tmax, + fstart=fstart, fend=fend, trans_bandwidth=trans_bandwidth) + ch_names = np.arange(n_signals).astype(str).tolist() + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') + data = EpochsArray(data, info) + + # define some frequencies for tfr + freqs = np.arange(3, 20.5, 1) + + # run connectivity estimation + if padding == 5: + with pytest.raises(ValueError): + con = spectral_connectivity_time( + data, freqs, sfreq=sfreq, method=method, mode=mode, + n_cycles=5, padding=padding) + return + else: + con = spectral_connectivity_time( + data, freqs, sfreq=sfreq, method=method, mode=mode, + n_cycles=5, padding=padding) + + assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) + assert con.get_data(output='dense').shape == \ + (n_epochs, n_signals, n_signals, len(con.freqs)) + + # test the simulated signal + triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T + + # average over frequencies + conn_data = con.get_data(output='dense').mean(axis=-1) + + # the indices at which there is a correlation should be greater + # then the rest of the components + for epoch_idx in range(n_epochs): + high_conn_val = conn_data[epoch_idx, 0, 1] + assert all(high_conn_val >= conn_data[epoch_idx, idx, jdx] + for idx, jdx in triu_inds) + + def test_save(tmp_path): """Test saving results of spectral connectivity.""" rng = np.random.RandomState(0) From 107dbb3c46a5bc82a13f1ca471ffca84e1e852ce Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Fri, 6 Jan 2023 16:18:00 +0200 Subject: [PATCH 20/30] Change whitespace --- mne_connectivity/spectral/time.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index a742580e..b8fb0c1e 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -24,8 +24,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, fmax=None, fskip=0, faverage=False, sm_times=0, sm_freqs=1, sm_kernel='hanning', padding=0, mode='cwt_morlet', mt_bandwidth=None, - n_cycles=7, decim=1, n_jobs=1, - verbose=None): + n_cycles=7, decim=1, n_jobs=1, verbose=None): """Compute time-frequency-domain connectivity measures. This function computes spectral connectivity over time from epoched data. @@ -41,7 +40,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, The data from which to compute connectivity. freqs : array_like Array of frequencies of interest for time-frequency decomposition. - Only the frequencies within the range specified by ``fmin`` and + Only the frequencies within the range specified by ``fmin`` and ``fmax`` are used. method : str | list of str Connectivity measure(s) to compute. These can be From 5ea3c4918da39c747b8861085045b78f79d70199 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Fri, 6 Jan 2023 16:26:13 +0200 Subject: [PATCH 21/30] Update whats_new.rst --- doc/whats_new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 503ab0af..758426d2 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -29,6 +29,8 @@ Enhancements - Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). - Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). - Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`). +- Add the ``ciPLV`` method in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`). +- Add the option to use the edges of each epoch as padding in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`). Bug ~~~ From dd45898b6f4eebcb86757078f7d322c24148f188 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 6 Jan 2023 12:54:43 -0500 Subject: [PATCH 22/30] Fix indentation and sphinx version Signed-off-by: Adam Li --- mne_connectivity/spectral/time.py | 16 ++++++++-------- requirements_doc.txt | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index b8fb0c1e..bf316e9a 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -45,12 +45,12 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'plv', 'sxy', 'pli', 'wpli']``. These are: - * 'coh' : Coherence - * 'plv' : Phase-Locking Value (PLV) - * 'ciplv' : Corrected imaginary Phase-Locking Value - * 'sxy' : Cross-spectrum - * 'pli' : Phase-Lag Index - * 'wpli': Weighted Phase-Lag Index + * 'coh' : Coherence + * 'plv' : Phase-Locking Value (PLV) + * 'ciplv' : Corrected imaginary Phase-Locking Value + * 'sxy' : Cross-spectrum + * 'pli' : Phase-Lag Index + * 'wpli' : Weighted Phase-Lag Index average : bool Average connectivity scores over epochs. If ``True``, output will be an instance of :class:`SpectralConnectivity`, otherwise @@ -185,8 +185,8 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, PLV = |E[Sxy/|Sxy|]| - 'ciplv' : Corrected imaginary PLV (icPLV) - :footcite:`BrunaEtAl2018` given by:: + 'ciplv' : Corrected imaginary PLV (icPLV) :footcite:`BrunaEtAl2018` + given by:: |E[Im(Sxy/|Sxy|)]| ciPLV = ------------------------------------ diff --git a/requirements_doc.txt b/requirements_doc.txt index 51d4fa53..28908b5e 100644 --- a/requirements_doc.txt +++ b/requirements_doc.txt @@ -1,5 +1,5 @@ memory_profiler -sphinx +sphinx<6.0 sphinx-gallery sphinx_rtd_theme sphinx-copybutton From ca3a8f605c320324efced32834f29577bd560eca Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Fri, 6 Jan 2023 19:54:45 +0200 Subject: [PATCH 23/30] Update doc Co-authored-by: Adam Li --- mne_connectivity/spectral/tests/test_spectral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index c24b3914..16c292fc 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -619,7 +619,7 @@ def test_spectral_connectivity_time_resolved(method, mode): 'mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('padding', [0, 1, 5]) def test_spectral_connectivity_time_padding(method, mode, padding): - """Test time-resolved spectral connectivity.""" + """Test time-resolved spectral connectivity with padding.""" sfreq = 50. n_signals = 3 n_epochs = 2 From 157ad2bf6fd38c1fe389548737aca54b4818707f Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 9 Jan 2023 11:31:27 +0200 Subject: [PATCH 24/30] Check for specific value error --- mne_connectivity/spectral/tests/test_spectral.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 16c292fc..3286bba8 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -642,7 +642,8 @@ def test_spectral_connectivity_time_padding(method, mode, padding): # run connectivity estimation if padding == 5: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='Padding cannot be larger than ' + 'half of data length'): con = spectral_connectivity_time( data, freqs, sfreq=sfreq, method=method, mode=mode, n_cycles=5, padding=padding) From 726232b735e347ebfff40c3642569484bc42b1cd Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 9 Jan 2023 11:39:24 +0200 Subject: [PATCH 25/30] Add note on padding --- mne_connectivity/spectral/time.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index bf316e9a..48d4bbc6 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -151,6 +151,14 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, using a weighted average, where the weights correspond to the concentration ratios between the DPSS windows. + Spectral estimation using multitaper or Morlet wavelets introduces edge + effects that depend on the length of the wavelet. To remove edge effects, + the parameter ``padding`` can be used to prune the edges of the signal. + Please see the documentation of + :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for details on wavelet length + (i.e., time window length). + By default, the connectivity between all signals is computed (only connections corresponding to the lower-triangular part of the connectivity matrix). If one is only interested in the connectivity From 4dd6214f5f5dffe9f4da99b24dc8a6f4ad3ce175 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 9 Jan 2023 11:43:27 +0200 Subject: [PATCH 26/30] Remove cs --- mne_connectivity/spectral/time.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 48d4bbc6..2e08c984 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -44,11 +44,10 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, ``fmax`` are used. method : str | list of str Connectivity measure(s) to compute. These can be - ``['coh', 'plv', 'sxy', 'pli', 'wpli']``. These are: + ``['coh', 'plv', 'pli', 'wpli']``. These are: * 'coh' : Coherence * 'plv' : Phase-Locking Value (PLV) * 'ciplv' : Corrected imaginary Phase-Locking Value - * 'sxy' : Cross-spectrum * 'pli' : Phase-Lag Index * 'wpli' : Weighted Phase-Lag Index average : bool @@ -200,8 +199,6 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, ciPLV = ------------------------------------ sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) - 'sxy' : Cross spectrum Sxy - 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: PLI = |E[sign(Im(Sxy))]| @@ -441,7 +438,8 @@ def _spectral_connectivity(data, method, kernel, foi_idx, window_length = np.arange(0., n_c / float(f), 1.0 / sfreq).shape[0] half_nbw = mt_bandwidth / 2. n_tapers = int(np.floor(mt_bandwidth - 1)) - _, eigvals = dpss_windows(window_length, half_nbw, n_tapers) + _, eigvals = dpss_windows(window_length, half_nbw, n_tapers, + sym=False) weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis]) # weights have shape (n_tapers, n_freqs, n_times) else: @@ -525,7 +523,7 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, s_xy = _smooth_spectra(s_xy, kernel) out = [] conn_func = {'plv': _plv, 'ciplv': _ciplv, 'pli': _pli, 'wpli': _wpli, - 'coh': _coh, 'cs': _cs} + 'coh': _coh} for m in method: if m == 'coh': s_xx = psd[x] @@ -579,10 +577,6 @@ def _coh(s_xx, s_yy, s_xy): return coh -def _cs(s_xy): - return s_xy.mean(axis=-1, keepdims=True) - - def _compute_csd(x, y, weights): """Compute cross spectral density of signals x and y.""" if weights is not None: From 6391560a19c440e9828eec0ef73b81f3970ad87a Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 9 Jan 2023 11:44:09 +0200 Subject: [PATCH 27/30] Add ciplv in doc --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 2e08c984..7599a79e 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -44,7 +44,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, ``fmax`` are used. method : str | list of str Connectivity measure(s) to compute. These can be - ``['coh', 'plv', 'pli', 'wpli']``. These are: + ``['coh', 'plv', 'ciplv', 'pli', 'wpli']``. These are: * 'coh' : Coherence * 'plv' : Phase-Locking Value (PLV) * 'ciplv' : Corrected imaginary Phase-Locking Value From 04426de1e80dc4ee6df81857397dea6c61d4705a Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 9 Jan 2023 11:51:55 +0200 Subject: [PATCH 28/30] Require mne>=1.3 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 067ea5c7..a564429a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy scipy -mne>=1.1 +mne>=1.3 xarray netCDF4 h5netcdf From 3a54f9954fa60c52bb635bb67271bb162fac0805 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Mon, 9 Jan 2023 18:13:37 +0200 Subject: [PATCH 29/30] Update doc Co-authored-by: Adam Li --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 7599a79e..b7b023cb 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -86,7 +86,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, Smoothing kernel type. Choose either 'square' or 'hanning'. padding : float Amount of time to consider as padding at the beginning and end of each - epoch in seconds. + epoch in seconds. See Notes for more information. mode : str Time-frequency decomposition method. Can be either: 'multitaper', or 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and From 0d6a1c42e9013fa2deb3297fe2f9ea1339680bf8 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Tue, 10 Jan 2023 15:02:19 +0200 Subject: [PATCH 30/30] Update private function docstrings --- mne_connectivity/spectral/time.py | 117 +++++++++++++++++++++++++++--- 1 file changed, 107 insertions(+), 10 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 7599a79e..25e4f0d0 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -414,8 +414,49 @@ def _spectral_connectivity(data, method, kernel, foi_idx, n_jobs, verbose): """Estimate time-resolved connectivity for one epoch. - Data is of shape (n_channels, n_times).""" + Parameters + ---------- + data : array_like, shape (n_channels, n_times) + Time-series data. + method : list of str + List of connectivity metrics to compute. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + source_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``target_idx``. + target_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``source_idx``. + mode : str + Time-frequency transformation method. + sfreq : float + Sampling frequency. + freqs : array_like + Array of frequencies of interest for time-frequency decomposition. + Only the frequencies within the range specified by ``fmin`` and + ``fmax`` are used. + faverage : bool + Average over frequency bands. + n_cycles : float | array_like of float + Number of cycles in the wavelet, either a fixed number or one per + frequency. + mt_bandwidth : float | None + Multitaper time-bandwidth. + decim : int + Decimation factor after time-frequency + decomposition. + padding : float + Amount of time to consider as padding at the beginning and end of each + epoch in seconds. + Returns + ------- + this_conn : list of array + List of connectivity estimates corresponding to the metrics in + ``method``. Each element is an array of shape (n_pairs, n_freqs) or + (n_pairs, n_fbands) if ``faverage`` is `True`. + """ n_pairs = len(source_idx) data = np.expand_dims(data, axis=0) if mode == 'cwt_morlet': @@ -478,8 +519,35 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, faverage, weights): """Compute spectral connectivity in parallel. - Input signal w is of shape (n_chans, n_tapers, n_freqs, n_times).""" + Parameters + ---------- + w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) + Time-frequency data (complex signal). + method : list of str + List of connectivity metrics to compute. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + source_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``target_idx``. + target_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``source_idx``. + n_jobs : int + Number of parallel jobs. + total : int + Number of pairs of signals. + faverage : bool + Average over frequency bands. + weights : array_like, shape (n_tapers, n_freqs, n_times) + Multitaper weights. + Returns + ------- + out : array_like, shape (n_pairs, n_methods, n_freqs_out) + Connectivity estimates for each signal pair, method, and frequency or + frequency band. + """ if 'coh' in method: # psd if weights is not None: @@ -512,7 +580,36 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, faverage, weights): - """Compute spectral connectivity metrics between two signals.""" + """Compute spectral connectivity metrics between two signals. + + Parameters + ---------- + w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) + Time-frequency data. + psd : array_like, shape (n_chans, n_freqs, n_times) + Power spectrum between signals ``x`` and ``y``. + x : int + Channel index. + y : int + Channel index. + method : str + Connectivity method. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + faverage : bool + Average over frequency bands. + weights : array_like, shape (n_tapers, n_freqs, n_times) | None + Multitaper weights. + + Returns + ------- + out : list + List of connectivity estimates between signals ``x`` and ``y`` + corresponding to the methods in ``method``. Each element is an array + with shape (n_freqs,) or (n_fbands) depending on ``faverage``. + """ w_x, w_y = w[x], w[y] if weights is not None: s_xy = np.sum(weights * w_x * np.conj(weights * w_y), axis=0) @@ -578,7 +675,7 @@ def _coh(s_xx, s_yy, s_xy): def _compute_csd(x, y, weights): - """Compute cross spectral density of signals x and y.""" + """Compute cross spectral density between signals x and y.""" if weights is not None: s_xy = np.sum(weights * x * np.conj(weights * y), axis=-3) s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=-3) @@ -595,15 +692,15 @@ def _foi_average(conn, foi_idx): Parameters ---------- - conn : np.ndarray - Array of shape (..., n_freqs, n_times) - foi_idx : array_like - Array of indices describing frequency bounds of shape (n_foi, 2) + conn : array_like, shape (..., n_freqs, n_times) + Connectivity estimate array. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower frequency bounds of each frequency band. Returns ------- - conn_f : np.ndarray - Array of shape (..., n_foi, n_times) + conn_f : np.ndarray, shape (..., n_fbands, n_times) + Connectivity estimate array, averaged within frequency bands. """ # get the number of foi n_foi = foi_idx.shape[0]