From b2f05e0394441f6f374bbb066131c5c049500a1e Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 20 Nov 2023 00:56:12 +0100 Subject: [PATCH 01/59] [ENH] Added CaCoh --- mne_connectivity/spectral/epochs.py | 297 ++++++++++++++---- .../example_multivariate_matlab_results.pkl | Bin 3310 -> 4147 bytes .../spectral/tests/test_spectral.py | 2 +- mne_connectivity/spectral/time.py | 39 ++- 4 files changed, 261 insertions(+), 77 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index f4899e87..ee069b1e 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -441,11 +441,7 @@ def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): - """Base estimator for multivariate imag. part of coherency methods. - - See Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 - for equation references. - """ + """Base estimator for multivariate coherency methods.""" name = None accumulate_psd = False @@ -455,8 +451,8 @@ def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): n_signals, n_cons, n_freqs, n_times, n_jobs) def compute_con(self, indices, ranks, n_epochs=1): - """Compute multivariate imag. 
part of coherency between signals.""" - assert self.name in ['MIC', 'MIM'], ( + """Compute multivariate coherency methods.""" + assert self.name in ['CaCoh', 'MIC', 'MIM'], ( 'the class name is not recognised, please contact the ' 'mne-connectivity developers') @@ -465,7 +461,7 @@ def compute_con(self, indices, ranks, n_epochs=1): times = np.arange(n_times) freqs = np.arange(self.n_freqs) - if self.name == 'MIC': + if self.name in ['CaCoh', 'MIC']: self.patterns = np.full( (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), np.nan) @@ -485,21 +481,19 @@ def compute_con(self, indices, ranks, n_epochs=1): C_bar, U_bar_aa, U_bar_bb = self._csd_svd( C, seed_idcs, seed_rank, target_rank) - # Eqs. 3 & 4 - E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) - - if self.name == 'MIC': - self._compute_mic(E, C, seed_idcs, target_idcs, n_times, - U_bar_aa, U_bar_bb, con_i) - else: - self._compute_mim(E, seed_idcs, target_idcs, con_i) + self._compute_con_daughter(seed_idcs, target_idcs, C, C_bar, + U_bar_aa, U_bar_bb, con_i) con_i += 1 self.reshape_results() def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): - """Dimensionality reduction of CSD with SVD.""" + """Dimensionality reduction of CSD with SVD. + + Eqs. 32 & 33 of Ewald et al. (2012). NeuroImage. DOI: + 10.1016/j.neuroimage.2011.11.084 + """ n_times = csd.shape[0] n_seeds = len(seed_idcs) n_targets = csd.shape[3] - n_seeds @@ -540,19 +534,24 @@ def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): return C_bar, U_bar_aa, U_bar_bb - def _compute_e(self, csd, n_seeds): - """Compute E from the CSD.""" - C_r = np.real(csd) + def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, + U_bar_bb, con_i): + """Compute multivariate coherency for one connection. + + An empty method to be implemented by subclasses. 
+ """ - parallel, parallel_compute_t, _ = parallel_func( - _mic_mim_compute_t, self.n_jobs, verbose=False) + def _compute_t(self, C_r, n_seeds): + """Compute transformation matrix, T, for frequencies (in parallel).""" + parallel, parallel_invsqrtm, _ = parallel_func( + _invsqrtm, self.n_jobs, verbose=False) # imag. part of T filled when data is rank-deficient - T = np.zeros(csd.shape, dtype=np.complex128) - for block_i in ProgressBar( - range(self.n_steps), mesg="frequency blocks"): + T = np.zeros(C_r.shape, dtype=np.complex128) + for block_i in ProgressBar(range(self.n_steps), + mesg="frequency blocks"): freqs = self._get_block_indices(block_i, self.n_freqs) - T[:, freqs] = np.array(parallel(parallel_compute_t( + T[:, freqs] = np.array(parallel(parallel_invsqrtm( C_r[:, f], T[:, f], n_seeds) for f in freqs) ).transpose(1, 0, 2, 3) @@ -562,19 +561,73 @@ def _compute_e(self, csd, n_seeds): 'and contain no NaN or infinity values; check that you are ' 'using full rank data or specify an appropriate rank for the ' 'seeds and targets that is less than or equal to their ranks') - T = np.real(T) # make T real if check passes + + return np.real(T) # make T real if check passes + + def reshape_results(self): + """Remove time dimension from results, if necessary.""" + if self.n_times == 0: + self.con_scores = self.con_scores[..., 0] + if self.patterns is not None: + self.patterns = self.patterns[..., 0] + + +def _invsqrtm(C, T, n_seeds): + """Compute inverse sqrt of CSD over times (used for CaCoh, MIC, & MIM). + + Kept as a standalone function to allow for parallelisation over CSD + frequencies. + """ + for time_i in range(C.shape[0]): + T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( + C[time_i, :n_seeds, :n_seeds], -0.5) + T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power( + C[time_i, n_seeds:, n_seeds:], -0.5) + + return T + + +class _MultivariateImCohEstBase(_MultivariateCohEstBase): + """Base estimator for multivariate imag. 
part of coherency methods. + + See Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 + for equation references. + """ + + def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, + U_bar_bb, con_i): + """Compute multivariate imag. part of coherency for one connection.""" + assert self.name in ['MIC', 'MIM'], ( + 'the class name is not recognised, please contact the ' + 'mne-connectivity developers') + + # Eqs. 3 & 4 + E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) + + if self.name == 'MIC': + self._compute_mic(E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, + con_i) + else: + self._compute_mim(E, seed_idcs, target_idcs, con_i) + + def _compute_e(self, C, n_seeds): + """Compute E from the CSD.""" + C_r = np.real(C) + + T = self._compute_t(C_r, n_seeds) # Eq. 4 - D = np.matmul(T, np.matmul(csd, T)) + D = np.matmul(T, np.matmul(C, T)) # E as imag. part of D between seeds and targets return np.imag(D[..., :n_seeds, n_seeds:]) - def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, - U_bar_bb, con_i): - """Compute MIC and the associated spatial patterns.""" + def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, + con_i): + """Compute MIC & spatial patterns for one connection.""" n_seeds = len(seed_idcs) n_targets = len(target_idcs) + n_times = C.shape[0] times = np.arange(n_times) freqs = np.arange(self.n_freqs) @@ -632,7 +685,7 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, ) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2)).T def _compute_mim(self, E, seed_idcs, target_idcs, con_i): - """Compute MIM (a.k.a. GIM if seeds == targets).""" + """Compute MIM (a.k.a. GIM if seeds == targets) for one connection.""" # Eq. 
14 self.con_scores[con_i] = np.matmul( E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T @@ -644,37 +697,137 @@ def _compute_mim(self, E, seed_idcs, target_idcs, con_i): ): self.con_scores[con_i] *= 0.5 - def reshape_results(self): - """Remove time dimension from results, if necessary.""" - if self.n_times == 0: - self.con_scores = self.con_scores[..., 0] - if self.patterns is not None: - self.patterns = self.patterns[..., 0] +class _MICEst(_MultivariateImCohEstBase): + """Multivariate imaginary part of coherency (MIC) estimator.""" -def _mic_mim_compute_t(C, T, n_seeds): - """Compute T for a single frequency (used for MIC and MIM).""" - for time_i in range(C.shape[0]): - T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( - C[time_i, :n_seeds, :n_seeds], -0.5 - ) - T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power( - C[time_i, n_seeds:, n_seeds:], -0.5 - ) + name = 'MIC' - return T +class _MIMEst(_MultivariateImCohEstBase): + """Multivariate interaction measure (MIM) estimator.""" -class _MICEst(_MultivariateCohEstBase): - """Multivariate imaginary part of coherency (MIC) estimator.""" + name = 'MIM' - name = "MIC" +class _CaCohEst(_MultivariateCohEstBase): + """Canonical coherence (CaCoh) estimator. -class _MIMEst(_MultivariateCohEstBase): - """Multivariate interaction measure (MIM) estimator.""" + See Vidaurre et al. (2019). NeuroImage. DOI: + 10.1016/j.neuroimage.2019.116009 for equation references. 
+ """ + + name = 'CaCoh' + + def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, + U_bar_bb, con_i): + """Compute CaCoh & spatial patterns for one connection.""" + assert self.name == 'CaCoh', ( + 'the class name is not recognised, please contact the ' + 'mne-connectivity developers') + n_seeds = len(seed_idcs) + n_targets = len(target_idcs) + + C_bar_ab = C_bar[..., :n_seeds, n_seeds:] + + T = self._compute_t(np.real(C_bar), n_seeds=U_bar_aa.shape[3]) + T_aa = T[..., :n_seeds, :n_seeds] + T_bb = T[..., n_seeds:, n_seeds:] + + max_coh, max_phis = self._first_optimise_phi(C_bar_ab, T_aa, T_bb) + + max_coh, max_phis = self._final_optimise_phi(C_bar_ab, T_aa, T_bb, + max_coh, max_phis) + + self.con_scores[con_i] = max_coh.T - name = "MIM" + self._compute_patterns(max_phis, C, C_bar_ab, T_aa, T_bb, U_bar_aa, + U_bar_bb, n_seeds, n_targets, con_i) + + def _first_optimise_phi(self, C_ab, T_aa, T_bb): + """Find the rough angle at which coherence is maximised.""" + n_iters = 5 + + # starting phi values to optimise over (in radians) + phis = np.array([(iter_i + 1) / n_iters * np.pi for iter_i in + range(n_iters)]) + phis_coh = np.zeros((n_iters, *C_ab.shape[:2])) + for iter_i, iter_phi in enumerate(phis): + phi = np.full(C_ab.shape[:2], fill_value=iter_phi) + phis_coh[iter_i] = self._compute_cacoh(phi, C_ab, T_aa, T_bb) + + return np.max(phis_coh, axis=0), phis[np.argmax(phis_coh, axis=0)] + + def _final_optimise_phi(self, C_ab, T_aa, T_bb, max_coh, max_phis): + """Fine-tune the angle at which coherence is maximised.""" + n_iters = 10 + delta_phi = 1e-6 + mus = np.ones_like(max_phis) + + for _ in range(n_iters): + coh_plus = self._compute_cacoh(max_phis + delta_phi, C_ab, T_aa, + T_bb) + coh_minus = self._compute_cacoh(max_phis - delta_phi, C_ab, T_aa, + T_bb) + + f_prime = (coh_plus - coh_minus) / (2 * delta_phi) + f_prime_prime = (coh_plus + coh_minus - 2 * max_coh) / ( + delta_phi ** 2) + phis = max_phis + (-f_prime / (f_prime_prime - mus)) + phis 
= np.mod(phis + np.pi / 2, np.pi) - np.pi / 2 + + coh = self._compute_cacoh(phis, C_ab, T_aa, T_bb) + + greater_coh = coh > max_coh + + mus[greater_coh] /= 2 + mus[~greater_coh] *= 2 + + # update coherence and phis + max_coh[greater_coh] = coh[greater_coh] + max_phis[greater_coh] = phis[greater_coh] + + return max_coh, phis + + def _compute_cacoh(self, phis, C_ab, T_aa, T_bb): + """Compute the maximum coherence for a given set of phis.""" + # from numerator of Eq. 5 + C_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_ab) + + # Eq. 9 + D = np.matmul(T_aa, np.matmul(C_ab, T_bb)) + + # Eq. 12 + a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] + b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] + + # Eq. 8 + numerator = np.einsum('ijk,ijk->ij', a, + np.matmul(D, np.expand_dims(b, axis=3))[..., 0]) + denominator = np.sqrt(np.einsum('ijk,ijk->ij', a, a) * + np.einsum('ijk,ijk->ij', b, b)) + + return np.abs(numerator / denominator) + + def _compute_patterns(self, phis, C, C_bar_ab, T_aa, T_bb, U_bar_aa, + U_bar_bb, n_seeds, n_targets, con_i): + """Compute CaCoh spatial patterns for the optimised phi.""" + C_bar_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * + C_bar_ab) + D = np.matmul(T_aa, np.matmul(C_bar_ab, T_bb)) + a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] + b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] + + # Eq. 
7 rearranged + alpha = np.matmul(T_aa, np.expand_dims(a, axis=3)) # filter for seeds + beta = np.matmul(T_bb, np.expand_dims(b, axis=3)) # filter for targets + + self.patterns[0, con_i, :n_seeds] = (np.matmul( + np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, alpha)) + )[..., 0].T + self.patterns[1, con_i, :n_targets] = (np.matmul( + np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, beta)) + )[..., 0].T class _PLVEst(_EpochMeanConEstBase): @@ -1281,7 +1434,7 @@ class _GCTREst(_GCEstBase): ############################################################################### -_multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr'] +_multivariate_methods = ['cacoh', 'mic', 'mim', 'gc', 'gc_tr'] _gc_methods = ['gc', 'gc_tr'] @@ -1505,8 +1658,9 @@ def _get_and_verify_data_sizes(data, sfreq, n_signals=None, n_times=None, 'plv': _PLVEst, 'ciplv': _ciPLVEst, 'ppc': _PPCEst, 'pli': _PLIEst, 'pli2_unbiased': _PLIUnbiasedEst, 'dpli': _DPLIEst, 'wpli': _WPLIEst, - 'wpli2_debiased': _WPLIDebiasedEst, 'mic': _MICEst, - 'mim': _MIMEst, 'gc': _GCEst, 'gc_tr': _GCTREst} + 'wpli2_debiased': _WPLIDebiasedEst, 'cacoh': _CaCohEst, + 'mic': _MICEst, 'mim': _MIMEst, 'gc': _GCEst, + 'gc_tr': _GCTREst} def _check_estimators(method): @@ -1563,9 +1717,10 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, %(names)s method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cohy', - 'imcoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', - 'wpli2_debiased', 'gc', 'gc_tr']``. Multivariate methods (``['mic', - 'mim', 'gc', 'gc_tr]``) cannot be called with the other methods. + 'imcoh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', + 'wpli', 'wpli2_debiased', 'gc', 'gc_tr']``. Multivariate methods + (``['cacoh', 'mic', 'mim', 'gc', 'gc_tr]``) cannot be called with the + other methods. indices : tuple of array | None Two arrays with indices of connections for which to compute connectivity. 
If a bivariate method is called, each array for the seeds @@ -1627,7 +1782,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, Two arrays with the rank to project the seed and target data to, respectively, using singular value decomposition. If None, the rank of the data is computed and projected to. Only used if ``method`` contains - any of ``['mic', 'mim', 'gc', 'gc_tr']``. + any of ``['cacoh', 'mic', 'mim', 'gc', 'gc_tr']``. block_size : int How many connections to compute at once (higher numbers are faster but require more memory). @@ -1733,6 +1888,20 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, C = ---------------------- sqrt(E[Sxx] * E[Syy]) + 'cacoh' : Canonical Coherence (CaCoh) :footcite:`VidaurreEtAl2019` + given by: + + :math:`CaCoh=\Large{\frac{\mid\boldsymbol{a}^T\boldsymbol{D}(\Phi) + \boldsymbol{b}\mid}{\sqrt{\boldsymbol{a}^T\boldsymbol{a} + \boldsymbol{b}^T\boldsymbol{b}}}}` + + where: :math:`\boldsymbol{D}(\Phi)` is the cross-spectral density + between seeds and targets transformed for a given phase angle + :math:`\Phi`; and :math:`\boldsymbol{a}` and :math:`\boldsymbol{b}` + are eigenvectors for the seeds and targets, such that :math:`\mid + \boldsymbol{a}^T\boldsymbol{D}(\Phi)\boldsymbol{b}\mid` maximises + coherence between the seeds and targets. + 'mic' : Maximised Imaginary part of Coherency (MIC) :footcite:`EwaldEtAl2012` given by: @@ -1744,8 +1913,8 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, transformed cross-spectral density between seeds and targets; and :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are eigenvectors for the seeds and targets, such that - :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises - connectivity between the seeds and targets. + :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises the + imaginary part of coherency between the seeds and targets. 
'mim' : Multivariate Interaction Measure (MIM) :footcite:`EwaldEtAl2012` given by: @@ -1794,7 +1963,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss - \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, + \lvert t}\boldsymbol{H}_{ts}^*\rvert}})` where: :math:`s` and :math:`t` represent the seeds and targets, respectively; :math:`\boldsymbol{H}` is the spectral transfer diff --git a/mne_connectivity/spectral/tests/data/example_multivariate_matlab_results.pkl b/mne_connectivity/spectral/tests/data/example_multivariate_matlab_results.pkl index 7b1de66583f2bf50b293647fa4b84135c9503554..dd1e2bcb4e9056ace9fa5325afcc582694a97096 100644 GIT binary patch delta 860 zcmV-i1Ec)z8M7b-fCZH(5U~a441bITV_{=&Xq0FKD3$;N00000005?g-%nNBz(4-V z`e%3@!aqi|!etb`#y|U_)S$ea!avNrFGhyzyFYBJ;?#&4$3Kt0zy&WV!ascwjK9^^ z#Xoc1=8ZI}!au-HhA}A!yg%lq7%y*xzdtKM9G5o zGe=?>Prg54cm5AXu)RM%1NN<|O0++uh_+L}^};`spN3Sj-@!k1`j9cb{JuXdR&s#n zjkZ7URyG;R1GYbC2*;edfWSZ6aqc+&`@lcCbG6?X+{Zt1P&Wh)C&WK9u~FGGQo%n- zTsm;NW5PeKyK_u1vBp2JwST(nf)&U=*4VaNKqSFG4B$Tu;*P;Tfn?mf8|1z}!BBI& z&=j;k00>{On*qWXf41$lxPTy0%kariHL`kQj{o5 m1859RPEJby|Ns9=|NsC0O8|6~VoL!_V1|@v4|J4LlyxrBeW&XH delta 17 Ycmdn2@J^DofpzL*o{g-}d6-J|06eb-Hvj+t diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index a514382b..fc120311 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -585,7 +585,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): os.path.join(fpath, 'data', 'example_multivariate_data.pkl')) sfreq = 100 indices = (np.array([[0, 1]]), np.array([[2, 3]])) - methods = ['mic', 'mim', 'gc', 'gc_tr'] + methods = ['mic', 'mim', 'cacoh', 'gc', 'gc_tr'] con = spectral_connectivity_epochs( data, method=methods, indices=indices, mode='multitaper', sfreq=sfreq, 
fskip=0, faverage=False, tmin=0, tmax=None, mt_bandwidth=4, diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 07dc4e57..eba724a8 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -13,13 +13,13 @@ from mne.utils import (logger, verbose) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) -from .epochs import (_MICEst, _MIMEst, _GCEst, _GCTREst, _compute_freq_mask, - _check_rank_input) +from .epochs import (_CaCohEst, _MICEst, _MIMEst, _GCEst, _GCTREst, + _compute_freq_mask, _check_rank_input) from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, _check_multivariate_indices, fill_doc -_multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr'] +_multivariate_methods = ['cacoh', 'mic', 'mim', 'gc', 'gc_tr'] _gc_methods = ['gc', 'gc_tr'] @@ -51,9 +51,10 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, ``fmax`` are used. method : str | list of str Connectivity measure(s) to compute. These can be - ``['coh', 'mic', 'mim', 'plv', 'ciplv', 'pli', 'wpli', 'gc', + ``['coh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'pli', 'wpli', 'gc', 'gc_tr']``. These are: * 'coh' : Coherence + * 'cacoh' : Canonical Coherence (CaCoh) * 'mic' : Maximised Imaginary part of Coherency (MIC) * 'mim' : Multivariate Interaction Measure (MIM) * 'plv' : Phase-Locking Value (PLV) @@ -62,8 +63,8 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, * 'wpli' : Weighted Phase-Lag Index * 'gc' : State-space Granger Causality (GC) * 'gc_tr' : State-space GC on time-reversed signals - Multivariate methods (``['mic', 'mim', 'gc', 'gc_tr]``) cannot be - called with the other methods. + Multivariate methods (``['cacoh', 'mic', 'mim', 'gc', 'gc_tr]``) cannot + be called with the other methods. average : bool Average connectivity scores over epochs. 
If ``True``, output will be an instance of :class:`SpectralConnectivity`, otherwise @@ -131,7 +132,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, Two arrays with the rank to project the seed and target data to, respectively, using singular value decomposition. If `None`, the rank of the data is computed and projected to. Only used if ``method`` - contains any of ``['mic', 'mim', 'gc', 'gc_tr']``. + contains any of ``['cacoh', 'mic', 'mim', 'gc', 'gc_tr']``. decim : int To reduce memory usage, decimation factor after time-frequency decomposition. Returns ``tfr[…, ::decim]``. @@ -232,6 +233,20 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, C = --------------------- sqrt(E[Sxx] * E[Syy]) + 'cacoh' : Canonical Coherence (CaCoh) :footcite:`VidaurreEtAl2019` + given by: + + :math:`CaCoh=\Large{\frac{\mid\boldsymbol{a}^T\boldsymbol{D}(\Phi) + \boldsymbol{b}\mid}{\sqrt{\boldsymbol{a}^T\boldsymbol{a} + \boldsymbol{b}^T\boldsymbol{b}}}}` + + where: :math:`\boldsymbol{D}(\Phi)` is the cross-spectral density + between seeds and targets transformed for a given phase angle + :math:`\Phi`; and :math:`\boldsymbol{a}` and :math:`\boldsymbol{b}` + are eigenvectors for the seeds and targets, such that :math:`\mid + \boldsymbol{a}^T\boldsymbol{D}(\Phi)\boldsymbol{b}\mid` maximises + coherence between the seeds and targets. + 'mic' : Maximised Imaginary part of Coherency (MIC) :footcite:`EwaldEtAl2012` given by: @@ -243,8 +258,8 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, transformed cross-spectral density between seeds and targets; and :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are eigenvectors for the seeds and targets, such that - :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises - connectivity between the seeds and targets. + :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises the + imaginary part of coherency between the seeds and targets. 
'mim' : Multivariate Interaction Measure (MIM) :footcite:`EwaldEtAl2012` given by: @@ -279,7 +294,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss - \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, + \lvert t}\boldsymbol{H}_{ts}^*\rvert}})` where: :math:`s` and :math:`t` represent the seeds and targets, respectively; :math:`\boldsymbol{H}` is the spectral transfer @@ -913,8 +928,8 @@ def _multivariate_con(w, seeds, targets, signals_use, method, kernel, foi_idx, csd = np.array(csd) # initialise connectivity estimators and add CSD information - conn_class = {'mic': _MICEst, 'mim': _MIMEst, 'gc': _GCEst, - 'gc_tr': _GCTREst} + conn_class = {'cacoh': _CaCohEst, 'mic': _MICEst, 'mim': _MIMEst, + 'gc': _GCEst, 'gc_tr': _GCTREst} conn = [] for m in method: call_params = {'n_signals': len(signals_use), 'n_cons': len(seeds), From 027701f7d53306301e270598e48ae90a25ddf7eb Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 21 Nov 2023 18:56:33 +0100 Subject: [PATCH 02/59] [MAINT] Switched to numpy for computing phis --- mne_connectivity/spectral/epochs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index ee069b1e..3f06a32d 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -749,8 +749,7 @@ def _first_optimise_phi(self, C_ab, T_aa, T_bb): n_iters = 5 # starting phi values to optimise over (in radians) - phis = np.array([(iter_i + 1) / n_iters * np.pi for iter_i in - range(n_iters)]) + phis = np.linspace(np.pi / n_iters, np.pi, n_iters) phis_coh = np.zeros((n_iters, *C_ab.shape[:2])) for iter_i, iter_phi in enumerate(phis): phi = np.full(C_ab.shape[:2], fill_value=iter_phi) From a743c754e994702e5a90908d43cb6d06b0417852 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 
21 Nov 2023 18:57:08 +0100 Subject: [PATCH 03/59] [DOC] Improved internal CaCoh comments --- mne_connectivity/spectral/epochs.py | 74 ++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 23 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 3f06a32d..01d7a35c 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -441,7 +441,12 @@ def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): - """Base estimator for multivariate coherency methods.""" + """Base estimator for multivariate coherency methods. + + See: + - Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 + - Vidaurre et al. (2019). NeuroImage. DOI: 10.1016/j.neuroimage.2019.116009 + """ name = None accumulate_psd = False @@ -477,7 +482,7 @@ def compute_con(self, indices, ranks, n_epochs=1): C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - # Eqs. 32 & 33 + # Eqs. 32 & 33 of Ewald et al.; Eq. 15 of Vidaurre et al. C_bar, U_bar_aa, U_bar_bb = self._csd_svd( C, seed_idcs, seed_rank, target_rank) @@ -489,11 +494,7 @@ def compute_con(self, indices, ranks, n_epochs=1): self.reshape_results() def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): - """Dimensionality reduction of CSD with SVD. - - Eqs. 32 & 33 of Ewald et al. (2012). NeuroImage. DOI: - 10.1016/j.neuroimage.2011.11.084 - """ + """Dimensionality reduction of CSD with SVD.""" n_times = csd.shape[0] n_seeds = len(seed_idcs) n_targets = csd.shape[3] - n_seeds @@ -503,7 +504,7 @@ def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): C_bb = csd[..., n_seeds:, n_seeds:] C_ba = csd[..., n_seeds:, :n_seeds] - # Eq. 32 + # Eqs. 32 (Ewald et al.) & 15 (Vidaurre et al.) 
if seed_rank != n_seeds: U_aa = np.linalg.svd(np.real(C_aa), full_matrices=False)[0] U_bar_aa = U_aa[..., :seed_rank] @@ -520,7 +521,7 @@ def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): np.identity(n_targets), (n_times, self.n_freqs) + (n_targets, n_targets)) - # Eq. 33 + # Eq. 33 (Ewald et al.) C_bar_aa = np.matmul( U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) C_bar_ab = np.matmul( @@ -542,7 +543,10 @@ def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, """ def _compute_t(self, C_r, n_seeds): - """Compute transformation matrix, T, for frequencies (in parallel).""" + """Compute transformation matrix, T, for frequencies (in parallel). + + Eq. 3 of Ewald et al.; part of Eq. 9 of Vidaurre et al. + """ parallel, parallel_invsqrtm, _ = parallel_func( _invsqrtm, self.n_jobs, verbose=False) @@ -577,6 +581,9 @@ def _invsqrtm(C, T, n_seeds): Kept as a standalone function to allow for parallelisation over CSD frequencies. + + See Eq. 3 of Ewald et al. (2012). NeuroImage. DOI: + 10.1016/j.neuroimage.2011.11.084 """ for time_i in range(C.shape[0]): T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( @@ -614,6 +621,7 @@ def _compute_e(self, C, n_seeds): """Compute E from the CSD.""" C_r = np.real(C) + # Eq. 3 T = self._compute_t(C_r, n_seeds) # Eq. 4 @@ -730,9 +738,10 @@ def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, C_bar_ab = C_bar[..., :n_seeds, n_seeds:] + # Same as Eq. 3 of Ewald et al. (2012) T = self._compute_t(np.real(C_bar), n_seeds=U_bar_aa.shape[3]) - T_aa = T[..., :n_seeds, :n_seeds] - T_bb = T[..., n_seeds:, n_seeds:] + T_aa = T[..., :n_seeds, :n_seeds] # left term in Eq. 9 + T_bb = T[..., n_seeds:, n_seeds:] # right term in Eq. 
9 max_coh, max_phis = self._first_optimise_phi(C_bar_ab, T_aa, T_bb) @@ -745,7 +754,7 @@ def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, U_bar_bb, n_seeds, n_targets, con_i) def _first_optimise_phi(self, C_ab, T_aa, T_bb): - """Find the rough angle at which coherence is maximised.""" + """Find the rough angle, phi, at which coherence is maximised.""" n_iters = 5 # starting phi values to optimise over (in radians) @@ -758,42 +767,58 @@ def _first_optimise_phi(self, C_ab, T_aa, T_bb): return np.max(phis_coh, axis=0), phis[np.argmax(phis_coh, axis=0)] def _final_optimise_phi(self, C_ab, T_aa, T_bb, max_coh, max_phis): - """Fine-tune the angle at which coherence is maximised.""" - n_iters = 10 + """Fine-tune the angle at which coherence is maximised. + + Uses a 2nd order Taylor expansion to approximate change in coherence + w.r.t. phi, and determining the next phi to evaluate coherence on (over + a total of 10 iterations). + + Depending on how the new phi affects coherence, the step size for the + subsequent iteration is adjusted, like that in the Levenberg-Marquardt + algorithm. + + Each time-freq. entry of coherence has its own corresponding phi. 
+ """ + n_iters = 10 # sufficient for (close to) exact solution delta_phi = 1e-6 - mus = np.ones_like(max_phis) + mus = np.ones_like(max_phis) # optimisation step size for _ in range(n_iters): + # 2nd order Taylor expansion around phi coh_plus = self._compute_cacoh(max_phis + delta_phi, C_ab, T_aa, T_bb) coh_minus = self._compute_cacoh(max_phis - delta_phi, C_ab, T_aa, T_bb) - f_prime = (coh_plus - coh_minus) / (2 * delta_phi) f_prime_prime = (coh_plus + coh_minus - 2 * max_coh) / ( delta_phi ** 2) + + # determine new phi to test phis = max_phis + (-f_prime / (f_prime_prime - mus)) + # bound phi in range [-pi, pi] phis = np.mod(phis + np.pi / 2, np.pi) - np.pi / 2 coh = self._compute_cacoh(phis, C_ab, T_aa, T_bb) + # find where new phi increases coh & update these values greater_coh = coh > max_coh + max_coh[greater_coh] = coh[greater_coh] + max_phis[greater_coh] = phis[greater_coh] + # update step size mus[greater_coh] /= 2 mus[~greater_coh] *= 2 - # update coherence and phis - max_coh[greater_coh] = coh[greater_coh] - max_phis[greater_coh] = phis[greater_coh] - return max_coh, phis def _compute_cacoh(self, phis, C_ab, T_aa, T_bb): """Compute the maximum coherence for a given set of phis.""" # from numerator of Eq. 5 + # for a given CSD entry, projects it onto a span with angle phi, such + # that the magnitude of the projected line is captured in the real part C_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_ab) - # Eq. 9 + # Eq. 9; T_aa/bb is sqrt(inv(real(C_aa/bb))) D = np.matmul(T_aa, np.matmul(C_ab, T_bb)) # Eq. 12 @@ -817,13 +842,16 @@ def _compute_patterns(self, phis, C, C_bar_ab, T_aa, T_bb, U_bar_aa, a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] - # Eq. 7 rearranged + # Eq. 
7 rearranged - multiply both sides by sqrt(inv(real(C_aa/bb))) alpha = np.matmul(T_aa, np.expand_dims(a, axis=3)) # filter for seeds beta = np.matmul(T_bb, np.expand_dims(b, axis=3)) # filter for targets + # Eq. 14; U_bar inclusion follows Eqs. 46 & 47 of Ewald et al. (2012) + # seed spatial patterns self.patterns[0, con_i, :n_seeds] = (np.matmul( np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, alpha)) )[..., 0].T + # target spatial patterns self.patterns[1, con_i, :n_targets] = (np.matmul( np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, beta)) )[..., 0].T From b755dceecf81a4904849ea84d293cb5d5a152b02 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 27 Nov 2023 16:17:25 +0100 Subject: [PATCH 04/59] [BUG] Fix incorrect CaCoh n_seeds after SVD --- mne_connectivity/spectral/epochs.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 4852005c..65fbda42 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -740,12 +740,14 @@ def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, n_seeds = len(seed_idcs) n_targets = len(target_idcs) - C_bar_ab = C_bar[..., :n_seeds, n_seeds:] + rank_seeds = U_bar_aa.shape[3] # n_seeds after SVD + + C_bar_ab = C_bar[..., :rank_seeds, rank_seeds:] # Same as Eq. 3 of Ewald et al. (2012) - T = self._compute_t(np.real(C_bar), n_seeds=U_bar_aa.shape[3]) - T_aa = T[..., :n_seeds, :n_seeds] # left term in Eq. 9 - T_bb = T[..., n_seeds:, n_seeds:] # right term in Eq. 9 + T = self._compute_t(np.real(C_bar), n_seeds=rank_seeds) + T_aa = T[..., :rank_seeds, :rank_seeds] # left term in Eq. 9 + T_bb = T[..., rank_seeds:, rank_seeds:] # right term in Eq. 
9 max_coh, max_phis = self._first_optimise_phi(C_bar_ab, T_aa, T_bb) From 8fa75d26499bd57f99b0aae2a4bc6803b7c5b421 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 28 Nov 2023 16:18:14 +0100 Subject: [PATCH 05/59] [MAINT] Updated authorship --- mne_connectivity/spectral/tests/data/README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mne_connectivity/spectral/tests/data/README.md b/mne_connectivity/spectral/tests/data/README.md index ea9da2bd..5af5d6b8 100644 --- a/mne_connectivity/spectral/tests/data/README.md +++ b/mne_connectivity/spectral/tests/data/README.md @@ -1,7 +1,8 @@ -Author: Thomas S. Binns +Authors: Thomas S. Binns + Mohammad Orabe The files found here are used for the regression test of the multivariate -connectivity methods for MIC, MIM, GC, and TRGC +connectivity methods for CaCoh, MIC, MIM, GC, and TRGC (`test_multivariate_spectral_connectivity_epochs_regression()` of `test_spectral.py`). @@ -9,8 +10,8 @@ connectivity methods for MIC, MIM, GC, and TRGC data with 15 epochs and 200 timepoints per epoch. Connectivity was computed in MATLAB using the original implementations of these methods and saved as a dictionary in `example_multivariate_matlab_results.pkl`. A publicly-available -implementation of the methods in MATLAB can be found here: -https://github.com/sccn/roiconnect. +implementation of the methods in MATLAB (except CaCoh) can be found here: +https://github.com/sccn/roiconnect. 
As the MNE code for computing the cross-spectral density matrix is not available in MATLAB, the CSD matrix was computed using MNE and then loaded into From 949ac3f61635e8766a1517e1f617590e736e09ec Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 28 Nov 2023 16:31:11 +0100 Subject: [PATCH 06/59] [MAINT] Updated authorship --- mne_connectivity/spectral/tests/data/README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/tests/data/README.md b/mne_connectivity/spectral/tests/data/README.md index 5af5d6b8..e50ed41a 100644 --- a/mne_connectivity/spectral/tests/data/README.md +++ b/mne_connectivity/spectral/tests/data/README.md @@ -1,5 +1,6 @@ -Authors: Thomas S. Binns - Mohammad Orabe +Authors: +- Thomas S. Binns +- Mohammad Orabe The files found here are used for the regression test of the multivariate connectivity methods for CaCoh, MIC, MIM, GC, and TRGC From bc491a85c7d322595f5da6ee976de5d4c9a1606d Mon Sep 17 00:00:00 2001 From: Mohammad Date: Thu, 30 Nov 2023 18:02:49 +0100 Subject: [PATCH 07/59] fix: update class methods for cacoh --- .../spectral/epochs_multivariate.py | 71 +------------------ 1 file changed, 1 insertion(+), 70 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 6390a5e9..bdc7afce 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -468,17 +468,7 @@ def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, U_bar_bb, n_seeds, n_targets, con_i) def _first_optimise_phi(self, C_ab, T_aa, T_bb): - """Find the rough angle, phi, at which coherence is maximised.""" - n_iters = 5 - - # starting phi values to optimise over (in radians) - phis = np.linspace(np.pi / n_iters, np.pi, n_iters) - phis_coh = np.zeros((n_iters, *C_ab.shape[:2])) - for iter_i, iter_phi in enumerate(phis): - phi = np.full(C_ab.shape[:2], fill_value=iter_phi) 
- phis_coh[iter_i] = self._compute_cacoh(phi, C_ab, T_aa, T_bb) - - return np.max(phis_coh, axis=0), phis[np.argmax(phis_coh, axis=0)] + pass def _final_optimise_phi(self, C_ab, T_aa, T_bb, max_coh, max_phis): """Fine-tune the angle at which coherence is maximised. @@ -497,33 +487,6 @@ def _final_optimise_phi(self, C_ab, T_aa, T_bb, max_coh, max_phis): delta_phi = 1e-6 mus = np.ones_like(max_phis) # optimisation step size - for _ in range(n_iters): - # 2nd order Taylor expansion around phi - coh_plus = self._compute_cacoh(max_phis + delta_phi, C_ab, T_aa, - T_bb) - coh_minus = self._compute_cacoh(max_phis - delta_phi, C_ab, T_aa, - T_bb) - f_prime = (coh_plus - coh_minus) / (2 * delta_phi) - f_prime_prime = (coh_plus + coh_minus - 2 * max_coh) / ( - delta_phi ** 2) - - # determine new phi to test - phis = max_phis + (-f_prime / (f_prime_prime - mus)) - # bound phi in range [-pi, pi] - phis = np.mod(phis + np.pi / 2, np.pi) - np.pi / 2 - - coh = self._compute_cacoh(phis, C_ab, T_aa, T_bb) - - # find where new phi increases coh & update these values - greater_coh = coh > max_coh - max_coh[greater_coh] = coh[greater_coh] - max_phis[greater_coh] = phis[greater_coh] - - # update step size - mus[greater_coh] /= 2 - mus[~greater_coh] *= 2 - - return max_coh, phis def _compute_cacoh(self, phis, C_ab, T_aa, T_bb): """Compute the maximum coherence for a given set of phis.""" @@ -532,43 +495,11 @@ def _compute_cacoh(self, phis, C_ab, T_aa, T_bb): # that the magnitude of the projected line is captured in the real part C_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_ab) - # Eq. 9; T_aa/bb is sqrt(inv(real(C_aa/bb))) - D = np.matmul(T_aa, np.matmul(C_ab, T_bb)) - - # Eq. 12 - a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] - b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] - - # Eq. 
8 - numerator = np.einsum('ijk,ijk->ij', a, - np.matmul(D, np.expand_dims(b, axis=3))[..., 0]) - denominator = np.sqrt(np.einsum('ijk,ijk->ij', a, a) * - np.einsum('ijk,ijk->ij', b, b)) - - return np.abs(numerator / denominator) - def _compute_patterns(self, phis, C, C_bar_ab, T_aa, T_bb, U_bar_aa, U_bar_bb, n_seeds, n_targets, con_i): """Compute CaCoh spatial patterns for the optimised phi.""" C_bar_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_bar_ab) - D = np.matmul(T_aa, np.matmul(C_bar_ab, T_bb)) - a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] - b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] - - # Eq. 7 rearranged - multiply both sides by sqrt(inv(real(C_aa/bb))) - alpha = np.matmul(T_aa, np.expand_dims(a, axis=3)) # filter for seeds - beta = np.matmul(T_bb, np.expand_dims(b, axis=3)) # filter for targets - - # Eq. 14; U_bar inclusion follows Eqs. 46 & 47 of Ewald et al. (2012) - # seed spatial patterns - self.patterns[0, con_i, :n_seeds] = (np.matmul( - np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, alpha)) - )[..., 0].T - # target spatial patterns - self.patterns[1, con_i, :n_targets] = (np.matmul( - np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, beta)) - )[..., 0].T class _GCEstBase(_EpochMeanMultivariateConEstBase): From 0960238516ff85ee601c5c5144aaa124665437d1 Mon Sep 17 00:00:00 2001 From: Mohammad Date: Thu, 30 Nov 2023 18:10:56 +0100 Subject: [PATCH 08/59] fix: implement class methods for cacoh --- .../spectral/epochs_multivariate.py | 71 ++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index bdc7afce..6390a5e9 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -468,7 +468,17 @@ def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, U_bar_bb, 
n_seeds, n_targets, con_i) def _first_optimise_phi(self, C_ab, T_aa, T_bb): - pass + """Find the rough angle, phi, at which coherence is maximised.""" + n_iters = 5 + + # starting phi values to optimise over (in radians) + phis = np.linspace(np.pi / n_iters, np.pi, n_iters) + phis_coh = np.zeros((n_iters, *C_ab.shape[:2])) + for iter_i, iter_phi in enumerate(phis): + phi = np.full(C_ab.shape[:2], fill_value=iter_phi) + phis_coh[iter_i] = self._compute_cacoh(phi, C_ab, T_aa, T_bb) + + return np.max(phis_coh, axis=0), phis[np.argmax(phis_coh, axis=0)] def _final_optimise_phi(self, C_ab, T_aa, T_bb, max_coh, max_phis): """Fine-tune the angle at which coherence is maximised. @@ -487,6 +497,33 @@ def _final_optimise_phi(self, C_ab, T_aa, T_bb, max_coh, max_phis): delta_phi = 1e-6 mus = np.ones_like(max_phis) # optimisation step size + for _ in range(n_iters): + # 2nd order Taylor expansion around phi + coh_plus = self._compute_cacoh(max_phis + delta_phi, C_ab, T_aa, + T_bb) + coh_minus = self._compute_cacoh(max_phis - delta_phi, C_ab, T_aa, + T_bb) + f_prime = (coh_plus - coh_minus) / (2 * delta_phi) + f_prime_prime = (coh_plus + coh_minus - 2 * max_coh) / ( + delta_phi ** 2) + + # determine new phi to test + phis = max_phis + (-f_prime / (f_prime_prime - mus)) + # bound phi in range [-pi, pi] + phis = np.mod(phis + np.pi / 2, np.pi) - np.pi / 2 + + coh = self._compute_cacoh(phis, C_ab, T_aa, T_bb) + + # find where new phi increases coh & update these values + greater_coh = coh > max_coh + max_coh[greater_coh] = coh[greater_coh] + max_phis[greater_coh] = phis[greater_coh] + + # update step size + mus[greater_coh] /= 2 + mus[~greater_coh] *= 2 + + return max_coh, phis def _compute_cacoh(self, phis, C_ab, T_aa, T_bb): """Compute the maximum coherence for a given set of phis.""" @@ -495,11 +532,43 @@ def _compute_cacoh(self, phis, C_ab, T_aa, T_bb): # that the magnitude of the projected line is captured in the real part C_ab = np.real(np.exp(-1j * np.expand_dims(phis, 
axis=(2, 3))) * C_ab) + # Eq. 9; T_aa/bb is sqrt(inv(real(C_aa/bb))) + D = np.matmul(T_aa, np.matmul(C_ab, T_bb)) + + # Eq. 12 + a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] + b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] + + # Eq. 8 + numerator = np.einsum('ijk,ijk->ij', a, + np.matmul(D, np.expand_dims(b, axis=3))[..., 0]) + denominator = np.sqrt(np.einsum('ijk,ijk->ij', a, a) * + np.einsum('ijk,ijk->ij', b, b)) + + return np.abs(numerator / denominator) + def _compute_patterns(self, phis, C, C_bar_ab, T_aa, T_bb, U_bar_aa, U_bar_bb, n_seeds, n_targets, con_i): """Compute CaCoh spatial patterns for the optimised phi.""" C_bar_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_bar_ab) + D = np.matmul(T_aa, np.matmul(C_bar_ab, T_bb)) + a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] + b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] + + # Eq. 7 rearranged - multiply both sides by sqrt(inv(real(C_aa/bb))) + alpha = np.matmul(T_aa, np.expand_dims(a, axis=3)) # filter for seeds + beta = np.matmul(T_bb, np.expand_dims(b, axis=3)) # filter for targets + + # Eq. 14; U_bar inclusion follows Eqs. 46 & 47 of Ewald et al. 
(2012) + # seed spatial patterns + self.patterns[0, con_i, :n_seeds] = (np.matmul( + np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, alpha)) + )[..., 0].T + # target spatial patterns + self.patterns[1, con_i, :n_targets] = (np.matmul( + np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, beta)) + )[..., 0].T class _GCEstBase(_EpochMeanMultivariateConEstBase): From b4230cc0f6ed80615dd3eb8b39f25adf53de9ea6 Mon Sep 17 00:00:00 2001 From: Mohammad Date: Thu, 30 Nov 2023 18:40:07 +0100 Subject: [PATCH 09/59] Add email address --- mne_connectivity/spectral/epochs_multivariate.py | 2 +- mne_connectivity/spectral/tests/data/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 6390a5e9..4957d001 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -4,7 +4,7 @@ # Thomas S. Binns # Tien D. Nguyen # Richard M. Köhler -# Mohammad Orabe +# Mohammad Orabe # # License: BSD (3-clause) diff --git a/mne_connectivity/spectral/tests/data/README.md b/mne_connectivity/spectral/tests/data/README.md index e50ed41a..4c62ebb6 100644 --- a/mne_connectivity/spectral/tests/data/README.md +++ b/mne_connectivity/spectral/tests/data/README.md @@ -1,6 +1,6 @@ Authors: - Thomas S. 
Binns -- Mohammad Orabe +- Mohammad Orabe The files found here are used for the regression test of the multivariate connectivity methods for CaCoh, MIC, MIM, GC, and TRGC From f688c337c480fcea27459a9878b0ba373b114a66 Mon Sep 17 00:00:00 2001 From: Mohammad Date: Fri, 1 Dec 2023 03:44:37 +0100 Subject: [PATCH 10/59] chore: reuse multivar tests for cacoh method --- .../spectral/tests/test_spectral.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 62c8c845..3c11af0e 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -414,7 +414,7 @@ def test_spectral_connectivity(method, mode): assert (out_lens[0] == 10) -@pytest.mark.parametrize('method', ['mic', 'mim', 'gc']) +@pytest.mark.parametrize('method', ['cacoh', 'mic', 'mim', 'gc']) def test_spectral_connectivity_epochs_multivariate(method): """Test over-epoch multivariate connectivity methods.""" mode = 'multitaper' # stick with single mode in interest of time @@ -456,7 +456,7 @@ def test_spectral_connectivity_epochs_multivariate(method): bidx = (freqs.index(fstart - trans_bandwidth * 2), freqs.index(fend + trans_bandwidth * 2) + 1) - if method in ['mic', 'mim']: + if method in ['cacoh', 'mic', 'mim']: lower_t = 0.2 upper_t = 0.5 @@ -497,12 +497,12 @@ def test_spectral_connectivity_epochs_multivariate(method): assert np.allclose(trgc[0, bidx[1]:].mean(), 0, atol=lower_t) # check all-to-all conn. 
computed for MIC/MIM when no indices given - if method in ['mic', 'mim']: + if method in ['cacoh', 'mic', 'mim']: con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=None, sfreq=sfreq) assert con.indices is None assert con.n_nodes == n_signals - if method == 'mic': + if method in ['cacoh', 'mic']: assert np.array(con.attrs['patterns']).shape[2] == n_signals # check ragged indices padded correctly @@ -513,7 +513,7 @@ def test_spectral_connectivity_epochs_multivariate(method): np.array([np.array([[0, -1]]), np.array([[1, 2]])])) # check shape of MIC patterns - if method == 'mic': + if method in ['cacoh', 'mic']: for mode in ['multitaper', 'cwt_morlet']: con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, @@ -604,7 +604,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): @pytest.mark.parametrize( - 'method', ['mic', 'mim', 'gc', 'gc_tr', ['mic', 'mim', 'gc', 'gc_tr']]) + 'method', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']]) @pytest.mark.parametrize('mode', ['multitaper', 'fourier', 'cwt_morlet']) def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): """Test error catching for multivar. 
freq.-domain connectivity methods.""" @@ -680,7 +680,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): gc_n_lags=10, cwt_freqs=cwt_freqs) assert rank_con.attrs["rank"] == (np.array([1]), np.array([1])) - if method in ['mic', 'mim']: + if method in ['cacoh', 'mic', 'mim']: # check rank-deficient transformation matrix caught with pytest.raises(RuntimeError, match='the transformation matrix'): @@ -731,7 +731,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): cwt_freqs=cwt_freqs) -@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize('method', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']) def test_multivar_spectral_connectivity_parallel(method): """Test multivar. freq.-domain connectivity methods run in parallel.""" sfreq = 50. @@ -862,7 +862,7 @@ def test_epochs_tmin_tmax(kind): @pytest.mark.parametrize( - 'method', ['coh', 'mic', 'mim', 'plv', 'pli', 'wpli', 'ciplv']) + 'method', ['cacoh', 'coh', 'mic', 'mim', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('data_option', ['sync', 'random']) def test_spectral_connectivity_time_phaselocked(method, mode, data_option): @@ -890,7 +890,7 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): n_times) data[i, c] = np.squeeze(np.sin(x)) - multivar_methods = ['mic', 'mim'] + multivar_methods = ['cacoh', 'mic', 'mim'] # the frequency band should contain the frequency at which there is a # hypothesized "connection" @@ -900,14 +900,14 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): con = spectral_connectivity_time( data, freqs, method=method, mode=mode, sfreq=sfreq, fmin=freq_band_low_limit, fmax=freq_band_high_limit, n_jobs=1, - faverage=True if method != 'mic' else False, - average=True if method != 'mic' else False, sm_times=0) + faverage=True if method not in ['cacoh', 'mic'] else False, + average=True if method 
not in ['cacoh', 'mic'] else False, sm_times=0) con_matrix = con.get_data() # MIC values can be pos. and neg., so must be averaged after taking the # absolute values for the test to work if method in multivar_methods: - if method == 'mic': + if method in ['cacoh', 'mic']: con_matrix = np.mean(np.abs(con_matrix), axis=(0, 2)) assert con.shape == (n_epochs, 1, len(con.freqs)) else: @@ -1139,7 +1139,7 @@ def test_spectral_connectivity_time_padding(method, mode, padding): for idx, jdx in triu_inds) -@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize('method', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']) @pytest.mark.parametrize('average', [True, False]) @pytest.mark.parametrize('faverage', [True, False]) def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): @@ -1169,7 +1169,7 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): assert con.shape == tuple(con_shape) # check shape of MIC patterns are correct - if method == 'mic': + if method in ['cacoh', 'mic']: for indices_type in ['full', 'ragged']: if indices_type == 'full': indices = (np.array([[0, 1]]), np.array([[2, 3]])) @@ -1200,7 +1200,7 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): (np.array([[0, 1]]), np.array([[2, -1]])))) -@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize('method', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']) @pytest.mark.parametrize('mode', ['multitaper', 'cwt_morlet']) def test_multivar_spectral_connectivity_time_error_catch(method, mode): """Test error catching for time-resolved multivar. connectivity methods.""" @@ -1258,12 +1258,12 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): mode=mode, rank=too_much_rank) # check all-to-all conn. 
computed for MIC/MIM when no indices given - if method in ['mic', 'mim']: + if method in ['cacoh', 'mic', 'mim']: con = spectral_connectivity_time( data, freqs, method=method, indices=None, sfreq=sfreq, mode=mode) assert con.indices is None assert con.n_nodes == n_signals - if method == 'mic': + if method == ['cacoh', 'mic']: assert np.array(con.attrs['patterns']).shape[3] == n_signals if method in ['gc', 'gc_tr']: @@ -1320,7 +1320,7 @@ def test_multivar_save_load(tmp_path): ragged_indices = (np.array([[0, 1]]), np.array([[2]])) for indices in [non_ragged_indices, ragged_indices]: con = spectral_connectivity_epochs( - epochs, method=['mic', 'mim', 'gc', 'gc_tr'], indices=indices, + epochs, method=['cacoh', 'mic', 'mim', 'gc', 'gc_tr'], indices=indices, sfreq=sfreq, fmin=10, fmax=30) for this_con in con: this_con.save(tmp_file) @@ -1378,7 +1378,7 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): assert con.indices is None and read_con.indices is None -@pytest.mark.parametrize("method", ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize("method", ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']) @pytest.mark.parametrize("indices", [None, (np.array([[0, 1]]), np.array([[2, 3]]))]) def test_multivar_spectral_connectivity_indices_roundtrip_io( From 8519442fab7cad3751c3a169d50e7dc7ab980fcf Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 1 Dec 2023 19:04:10 +0100 Subject: [PATCH 11/59] [BUG] Fixed tests and flake --- .../spectral/tests/test_spectral.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 3c11af0e..88545e0b 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -604,7 +604,8 @@ def test_multivariate_spectral_connectivity_epochs_regression(): @pytest.mark.parametrize( - 'method', ['cacoh', 'mic', 'mim', 'gc', 
'gc_tr', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']]) + 'method', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr', + ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']]) @pytest.mark.parametrize('mode', ['multitaper', 'fourier', 'cwt_morlet']) def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): """Test error catching for multivar. freq.-domain connectivity methods.""" @@ -862,7 +863,7 @@ def test_epochs_tmin_tmax(kind): @pytest.mark.parametrize( - 'method', ['cacoh', 'coh', 'mic', 'mim', 'plv', 'pli', 'wpli', 'ciplv']) + 'method', ['coh', 'cacoh', 'mic', 'mim', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('data_option', ['sync', 'random']) def test_spectral_connectivity_time_phaselocked(method, mode, data_option): @@ -870,7 +871,7 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): data.""" rng = np.random.default_rng(0) n_epochs = 5 - n_channels = 3 + n_channels = 4 n_times = 1000 sfreq = 250 data = np.zeros((n_epochs, n_channels, n_times)) @@ -892,22 +893,29 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): multivar_methods = ['cacoh', 'mic', 'mim'] + if method == 'cacoh': + # CaCoh within set of signals will always be 1, so need to specify + # distinct seeds and targets + indices = ([[0, 1]], [[2, 3]]) + else: + indices = None + # the frequency band should contain the frequency at which there is a # hypothesized "connection" freq_band_low_limit = (8.) freq_band_high_limit = (13.) 
freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) con = spectral_connectivity_time( - data, freqs, method=method, mode=mode, sfreq=sfreq, + data, freqs, indices=indices, method=method, mode=mode, sfreq=sfreq, fmin=freq_band_low_limit, fmax=freq_band_high_limit, n_jobs=1, - faverage=True if method not in ['cacoh', 'mic'] else False, - average=True if method not in ['cacoh', 'mic'] else False, sm_times=0) + faverage=True if method != 'mic' else False, + average=True if method != 'mic' else False, sm_times=0) con_matrix = con.get_data() # MIC values can be pos. and neg., so must be averaged after taking the # absolute values for the test to work if method in multivar_methods: - if method in ['cacoh', 'mic']: + if method in ['mic']: con_matrix = np.mean(np.abs(con_matrix), axis=(0, 2)) assert con.shape == (n_epochs, 1, len(con.freqs)) else: @@ -1320,8 +1328,8 @@ def test_multivar_save_load(tmp_path): ragged_indices = (np.array([[0, 1]]), np.array([[2]])) for indices in [non_ragged_indices, ragged_indices]: con = spectral_connectivity_epochs( - epochs, method=['cacoh', 'mic', 'mim', 'gc', 'gc_tr'], indices=indices, - sfreq=sfreq, fmin=10, fmax=30) + epochs, method=['cacoh', 'mic', 'mim', 'gc', 'gc_tr'], + indices=indices, sfreq=sfreq, fmin=10, fmax=30) for this_con in con: this_con.save(tmp_file) read_con = read_connectivity(tmp_file) From e93fb47b3c5bb20164d5bb347b40692a52cfe430 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 11 Dec 2023 14:24:26 +0100 Subject: [PATCH 12/59] [MAINT] Black format changed code --- .../spectral/epochs_multivariate.py | 496 +++--- .../spectral/tests/test_spectral.py | 1439 +++++++++++------ 2 files changed, 1195 insertions(+), 740 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 4957d001..5a92c7a2 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -34,35 +34,46 
@@ def _check_rank_input(rank, data, indices): for group_i in range(2): # seeds and targets for con_i, con_idcs in enumerate(indices[group_i]): - s = np.linalg.svd(data_arr[:, con_idcs.compressed()], - compute_uv=False) + s = np.linalg.svd(data_arr[:, con_idcs.compressed()], compute_uv=False) rank[group_i][con_i] = np.min( - [np.count_nonzero(epoch >= epoch[0] * sv_tol) - for epoch in s]) + [np.count_nonzero(epoch >= epoch[0] * sv_tol) for epoch in s] + ) - logger.info('Estimated data ranks:') + logger.info("Estimated data ranks:") con_i = 1 for seed_rank, target_rank in zip(rank[0], rank[1]): - logger.info(' connection %i - seeds (%i); targets (%i)' - % (con_i, seed_rank, target_rank, )) + logger.info( + " connection %i - seeds (%i); targets (%i)" + % ( + con_i, + seed_rank, + target_rank, + ) + ) con_i += 1 rank = tuple((np.array(rank[0]), np.array(rank[1]))) else: if ( - len(rank) != 2 or len(rank[0]) != len(indices[0]) or - len(rank[1]) != len(indices[1]) + len(rank) != 2 + or len(rank[0]) != len(indices[0]) + or len(rank[1]) != len(indices[1]) ): - raise ValueError('rank argument must have shape (2, n_cons), ' - 'according to n_cons in the indices') + raise ValueError( + "rank argument must have shape (2, n_cons), " + "according to n_cons in the indices" + ) for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], rank[0], rank[1]): - if not (0 < seed_rank <= len(seed_idcs) and - 0 < target_rank <= len(target_idcs)): + indices[0], indices[1], rank[0], rank[1] + ): + if not ( + 0 < seed_rank <= len(seed_idcs) and 0 < target_rank <= len(target_idcs) + ): raise ValueError( - 'ranks for seeds and targets must be > 0 and <= the ' - 'number of channels in the seeds and targets, ' - 'respectively, for each connection') + "ranks for seeds and targets must be > 0 and <= the " + "number of channels in the seeds and targets, " + "respectively, for each connection" + ) return rank @@ -75,16 +86,16 @@ class _AbstractConEstBase(object): """ABC 
for connectivity estimators.""" def start_epoch(self): - raise NotImplementedError('start_epoch method not implemented') + raise NotImplementedError("start_epoch method not implemented") def accumulate(self, con_idx, csd_xy): - raise NotImplementedError('accumulate method not implemented') + raise NotImplementedError("accumulate method not implemented") def combine(self, other): - raise NotImplementedError('combine method not implemented') + raise NotImplementedError("combine method not implemented") def compute_con(self, con_idx, n_epochs): - raise NotImplementedError('compute_con method not implemented') + raise NotImplementedError("compute_con method not implemented") class _EpochMeanMultivariateConEstBase(_AbstractConEstBase): @@ -131,8 +142,14 @@ def _compute_n_progress_bar_steps(self): def _log_connection_number(self, con_i): """Log the number of the connection being computed.""" - logger.info('Computing %s for connection %i of %i' - % (self.name, con_i + 1, self.n_cons, )) + logger.info( + "Computing %s for connection %i of %i" + % ( + self.name, + con_i + 1, + self.n_cons, + ) + ) def _get_block_indices(self, block_i, limit): """Get indices for a computation block capped by a limit.""" @@ -143,13 +160,13 @@ def _get_block_indices(self, block_i, limit): def reshape_csd(self): """Reshape CSD into a matrix of times x freqs x signals x signals.""" if self.n_times == 0: - return (np.reshape(self._acc, ( - self.n_signals, self.n_signals, self.n_freqs, 1) - ).transpose(3, 2, 0, 1)) + return np.reshape( + self._acc, (self.n_signals, self.n_signals, self.n_freqs, 1) + ).transpose(3, 2, 0, 1) - return (np.reshape(self._acc, ( - self.n_signals, self.n_signals, self.n_freqs, self.n_times) - ).transpose(3, 2, 0, 1)) + return np.reshape( + self._acc, (self.n_signals, self.n_signals, self.n_freqs, self.n_times) + ).transpose(3, 2, 0, 1) class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): @@ -165,27 +182,30 @@ class 
_MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): super(_MultivariateCohEstBase, self).__init__( - n_signals, n_cons, n_freqs, n_times, n_jobs) + n_signals, n_cons, n_freqs, n_times, n_jobs + ) def compute_con(self, indices, ranks, n_epochs=1): """Compute multivariate coherency methods.""" - assert self.name in ['CaCoh', 'MIC', 'MIM'], ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') + assert self.name in ["CaCoh", "MIC", "MIM"], ( + "the class name is not recognised, please contact the " + "mne-connectivity developers" + ) csd = self.reshape_csd() / n_epochs n_times = csd.shape[0] times = np.arange(n_times) freqs = np.arange(self.n_freqs) - if self.name in ['CaCoh', 'MIC']: + if self.name in ["CaCoh", "MIC"]: self.patterns = np.full( - (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), - np.nan) + (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), np.nan + ) con_i = 0 for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], ranks[0], ranks[1]): + indices[0], indices[1], ranks[0], ranks[1] + ): self._log_connection_number(con_i) seed_idcs = seed_idcs.compressed() @@ -196,10 +216,12 @@ def compute_con(self, indices, ranks, n_epochs=1): # Eqs. 32 & 33 of Ewald et al.; Eq. 15 of Vidaurre et al. 
C_bar, U_bar_aa, U_bar_bb = self._csd_svd( - C, seed_idcs, seed_rank, target_rank) + C, seed_idcs, seed_rank, target_rank + ) - self._compute_con_daughter(seed_idcs, target_idcs, C, C_bar, - U_bar_aa, U_bar_bb, con_i) + self._compute_con_daughter( + seed_idcs, target_idcs, C, C_bar, U_bar_aa, U_bar_bb, con_i + ) con_i += 1 @@ -222,33 +244,33 @@ def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): U_bar_aa = U_aa[..., :seed_rank] else: U_bar_aa = np.broadcast_to( - np.identity(n_seeds), - (n_times, self.n_freqs) + (n_seeds, n_seeds)) + np.identity(n_seeds), (n_times, self.n_freqs) + (n_seeds, n_seeds) + ) if target_rank != n_targets: U_bb = np.linalg.svd(np.real(C_bb), full_matrices=False)[0] U_bar_bb = U_bb[..., :target_rank] else: U_bar_bb = np.broadcast_to( - np.identity(n_targets), - (n_times, self.n_freqs) + (n_targets, n_targets)) + np.identity(n_targets), (n_times, self.n_freqs) + (n_targets, n_targets) + ) # Eq. 33 (Ewald et al.) - C_bar_aa = np.matmul( - U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul( - U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul( - U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul( - U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa)) - C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), - np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) + C_bar_aa = np.matmul(U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) + C_bar_ab = np.matmul(U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb)) + C_bar_bb = np.matmul(U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb)) + C_bar_ba = np.matmul(U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa)) + C_bar = np.append( + np.append(C_bar_aa, C_bar_ab, axis=3), + np.append(C_bar_ba, C_bar_bb, axis=3), + axis=2, + ) return C_bar, U_bar_aa, U_bar_bb - def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, - U_bar_bb, con_i): + def 
_compute_con_daughter( + self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, U_bar_bb, con_i + ): """Compute multivariate coherency for one connection. An empty method to be implemented by subclasses. @@ -260,23 +282,24 @@ def _compute_t(self, C_r, n_seeds): Eq. 3 of Ewald et al.; part of Eq. 9 of Vidaurre et al. """ parallel, parallel_invsqrtm, _ = parallel_func( - _invsqrtm, self.n_jobs, verbose=False) + _invsqrtm, self.n_jobs, verbose=False + ) # imag. part of T filled when data is rank-deficient T = np.zeros(C_r.shape, dtype=np.complex128) - for block_i in ProgressBar(range(self.n_steps), - mesg="frequency blocks"): + for block_i in ProgressBar(range(self.n_steps), mesg="frequency blocks"): freqs = self._get_block_indices(block_i, self.n_freqs) - T[:, freqs] = np.array(parallel(parallel_invsqrtm( - C_r[:, f], T[:, f], n_seeds) for f in freqs) + T[:, freqs] = np.array( + parallel(parallel_invsqrtm(C_r[:, f], T[:, f], n_seeds) for f in freqs) ).transpose(1, 0, 2, 3) if not np.isreal(T).all() or not np.isfinite(T).all(): raise RuntimeError( - 'the transformation matrix of the data must be real-valued ' - 'and contain no NaN or infinity values; check that you are ' - 'using full rank data or specify an appropriate rank for the ' - 'seeds and targets that is less than or equal to their ranks') + "the transformation matrix of the data must be real-valued " + "and contain no NaN or infinity values; check that you are " + "using full rank data or specify an appropriate rank for the " + "seeds and targets that is less than or equal to their ranks" + ) return np.real(T) # make T real if check passes @@ -299,9 +322,11 @@ def _invsqrtm(C, T, n_seeds): """ for time_i in range(C.shape[0]): T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( - C[time_i, :n_seeds, :n_seeds], -0.5) + C[time_i, :n_seeds, :n_seeds], -0.5 + ) T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power( - C[time_i, n_seeds:, n_seeds:], -0.5) + C[time_i, n_seeds:, n_seeds:], -0.5 
+ ) return T @@ -313,19 +338,20 @@ class _MultivariateImCohEstBase(_MultivariateCohEstBase): for equation references. """ - def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, - U_bar_bb, con_i): + def _compute_con_daughter( + self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, U_bar_bb, con_i + ): """Compute multivariate imag. part of coherency for one connection.""" - assert self.name in ['MIC', 'MIM'], ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') + assert self.name in ["MIC", "MIM"], ( + "the class name is not recognised, please contact the " + "mne-connectivity developers" + ) # Eqs. 3 & 4 E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) - if self.name == 'MIC': - self._compute_mic(E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, - con_i) + if self.name == "MIC": + self._compute_mic(E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i) else: self._compute_mim(E, seed_idcs, target_idcs, con_i) @@ -342,8 +368,7 @@ def _compute_e(self, C, n_seeds): # E as imag. part of D between seeds and targets return np.imag(D[..., :n_seeds, n_seeds:]) - def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, - con_i): + def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i): """Compute MIC & spatial patterns for one connection.""" n_seeds = len(seed_idcs) n_targets = len(target_idcs) @@ -352,13 +377,10 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, freqs = np.arange(self.n_freqs) # Eigendecomp. 
to find spatial filters for seeds and targets - w_seeds, V_seeds = np.linalg.eigh( - np.matmul(E, E.transpose(0, 1, 3, 2))) - w_targets, V_targets = np.linalg.eigh( - np.matmul(E.transpose(0, 1, 3, 2), E)) - if ( - len(seed_idcs) == len(target_idcs) and - np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + w_seeds, V_seeds = np.linalg.eigh(np.matmul(E, E.transpose(0, 1, 3, 2))) + w_targets, V_targets = np.linalg.eigh(np.matmul(E.transpose(0, 1, 3, 2), E)) + if len(seed_idcs) == len(target_idcs) and np.all( + np.sort(seed_idcs) == np.sort(target_idcs) ): # strange edge-case where the eigenvectors returned should be a set # of identity matrices with one rotated by 90 degrees, but are @@ -372,8 +394,7 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, while not create_filter and not stop: for time_i in range(n_times): for freq_i in range(self.n_freqs): - if np.all(V_seeds[time_i, freq_i] == - V_targets[time_i, freq_i]): + if np.all(V_seeds[time_i, freq_i] == V_targets[time_i, freq_i]): create_filter = True break stop = True @@ -389,31 +410,40 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] # Eq. 46 (seed spatial patterns) - self.patterns[0, con_i, :n_seeds] = (np.matmul( - np.real(C[..., :n_seeds, :n_seeds]), - np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T + self.patterns[0, con_i, :n_seeds] = ( + np.matmul( + np.real(C[..., :n_seeds, :n_seeds]), + np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3)), + ) + )[..., 0].T # Eq. 47 (target spatial patterns) - self.patterns[1, con_i, :n_targets] = (np.matmul( - np.real(C[..., n_seeds:, n_seeds:]), - np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T + self.patterns[1, con_i, :n_targets] = ( + np.matmul( + np.real(C[..., n_seeds:, n_seeds:]), + np.matmul(U_bar_bb, np.expand_dims(beta, axis=3)), + ) + )[..., 0].T # Eq. 
7 - self.con_scores[con_i] = (np.einsum( - 'ijk,ijk->ij', alpha, np.matmul(E, np.expand_dims( - beta, axis=3))[..., 0] - ) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2)).T + self.con_scores[con_i] = ( + np.einsum( + "ijk,ijk->ij", alpha, np.matmul(E, np.expand_dims(beta, axis=3))[..., 0] + ) + / np.linalg.norm(alpha, axis=2) + * np.linalg.norm(beta, axis=2) + ).T def _compute_mim(self, E, seed_idcs, target_idcs, con_i): """Compute MIM (a.k.a. GIM if seeds == targets) for one connection.""" # Eq. 14 - self.con_scores[con_i] = np.matmul( - E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T + self.con_scores[con_i] = ( + np.matmul(E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T + ) # Eq. 15 - if ( - len(seed_idcs) == len(target_idcs) and - np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + if len(seed_idcs) == len(target_idcs) and np.all( + np.sort(seed_idcs) == np.sort(target_idcs) ): self.con_scores[con_i] *= 0.5 @@ -421,13 +451,13 @@ def _compute_mim(self, E, seed_idcs, target_idcs, con_i): class _MICEst(_MultivariateImCohEstBase): """Multivariate imaginary part of coherency (MIC) estimator.""" - name = 'MIC' + name = "MIC" class _MIMEst(_MultivariateImCohEstBase): """Multivariate interaction measure (MIM) estimator.""" - name = 'MIM' + name = "MIM" class _CaCohEst(_MultivariateCohEstBase): @@ -437,14 +467,16 @@ class _CaCohEst(_MultivariateCohEstBase): 10.1016/j.neuroimage.2019.116009 for equation references. 
""" - name = 'CaCoh' + name = "CaCoh" - def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, - U_bar_bb, con_i): + def _compute_con_daughter( + self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, U_bar_bb, con_i + ): """Compute CaCoh & spatial patterns for one connection.""" - assert self.name == 'CaCoh', ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') + assert self.name == "CaCoh", ( + "the class name is not recognised, please contact the " + "mne-connectivity developers" + ) n_seeds = len(seed_idcs) n_targets = len(target_idcs) @@ -459,13 +491,24 @@ def _compute_con_daughter(self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, max_coh, max_phis = self._first_optimise_phi(C_bar_ab, T_aa, T_bb) - max_coh, max_phis = self._final_optimise_phi(C_bar_ab, T_aa, T_bb, - max_coh, max_phis) + max_coh, max_phis = self._final_optimise_phi( + C_bar_ab, T_aa, T_bb, max_coh, max_phis + ) self.con_scores[con_i] = max_coh.T - self._compute_patterns(max_phis, C, C_bar_ab, T_aa, T_bb, U_bar_aa, - U_bar_bb, n_seeds, n_targets, con_i) + self._compute_patterns( + max_phis, + C, + C_bar_ab, + T_aa, + T_bb, + U_bar_aa, + U_bar_bb, + n_seeds, + n_targets, + con_i, + ) def _first_optimise_phi(self, C_ab, T_aa, T_bb): """Find the rough angle, phi, at which coherence is maximised.""" @@ -499,13 +542,10 @@ def _final_optimise_phi(self, C_ab, T_aa, T_bb, max_coh, max_phis): for _ in range(n_iters): # 2nd order Taylor expansion around phi - coh_plus = self._compute_cacoh(max_phis + delta_phi, C_ab, T_aa, - T_bb) - coh_minus = self._compute_cacoh(max_phis - delta_phi, C_ab, T_aa, - T_bb) + coh_plus = self._compute_cacoh(max_phis + delta_phi, C_ab, T_aa, T_bb) + coh_minus = self._compute_cacoh(max_phis - delta_phi, C_ab, T_aa, T_bb) f_prime = (coh_plus - coh_minus) / (2 * delta_phi) - f_prime_prime = (coh_plus + coh_minus - 2 * max_coh) / ( - delta_phi ** 2) + f_prime_prime = (coh_plus + coh_minus - 2 * max_coh) / (delta_phi**2) # 
determine new phi to test phis = max_phis + (-f_prime / (f_prime_prime - mus)) @@ -540,18 +580,30 @@ def _compute_cacoh(self, phis, C_ab, T_aa, T_bb): b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] # Eq. 8 - numerator = np.einsum('ijk,ijk->ij', a, - np.matmul(D, np.expand_dims(b, axis=3))[..., 0]) - denominator = np.sqrt(np.einsum('ijk,ijk->ij', a, a) * - np.einsum('ijk,ijk->ij', b, b)) + numerator = np.einsum( + "ijk,ijk->ij", a, np.matmul(D, np.expand_dims(b, axis=3))[..., 0] + ) + denominator = np.sqrt( + np.einsum("ijk,ijk->ij", a, a) * np.einsum("ijk,ijk->ij", b, b) + ) return np.abs(numerator / denominator) - def _compute_patterns(self, phis, C, C_bar_ab, T_aa, T_bb, U_bar_aa, - U_bar_bb, n_seeds, n_targets, con_i): + def _compute_patterns( + self, + phis, + C, + C_bar_ab, + T_aa, + T_bb, + U_bar_aa, + U_bar_bb, + n_seeds, + n_targets, + con_i, + ): """Compute CaCoh spatial patterns for the optimised phi.""" - C_bar_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * - C_bar_ab) + C_bar_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_bar_ab) D = np.matmul(T_aa, np.matmul(C_bar_ab, T_bb)) a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] @@ -562,12 +614,12 @@ def _compute_patterns(self, phis, C, C_bar_ab, T_aa, T_bb, U_bar_aa, # Eq. 14; U_bar inclusion follows Eqs. 46 & 47 of Ewald et al. 
(2012) # seed spatial patterns - self.patterns[0, con_i, :n_seeds] = (np.matmul( - np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, alpha)) + self.patterns[0, con_i, :n_seeds] = ( + np.matmul(np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, alpha)) )[..., 0].T # target spatial patterns - self.patterns[1, con_i, :n_targets] = (np.matmul( - np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, beta)) + self.patterns[1, con_i, :n_targets] = ( + np.matmul(np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, beta)) )[..., 0].T @@ -577,21 +629,26 @@ class _GCEstBase(_EpochMeanMultivariateConEstBase): accumulate_psd = False def __init__(self, n_signals, n_cons, n_freqs, n_times, n_lags, n_jobs=1): - super(_GCEstBase, self).__init__( - n_signals, n_cons, n_freqs, n_times, n_jobs) + super(_GCEstBase, self).__init__(n_signals, n_cons, n_freqs, n_times, n_jobs) self.freq_res = (self.n_freqs - 1) * 2 if n_lags >= self.freq_res: raise ValueError( - 'the number of lags (%i) must be less than double the ' - 'frequency resolution (%i)' % (n_lags, self.freq_res, )) + "the number of lags (%i) must be less than double the " + "frequency resolution (%i)" + % ( + n_lags, + self.freq_res, + ) + ) self.n_lags = n_lags def compute_con(self, indices, ranks, n_epochs=1): """Compute multivariate state-space Granger causality.""" - assert self.name in ['GC', 'GC time-reversed'], ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') + assert self.name in ["GC", "GC time-reversed"], ( + "the class name is not recognised, please contact the " + "mne-connectivity developers" + ) csd = self.reshape_csd() / n_epochs @@ -601,7 +658,8 @@ def compute_con(self, indices, ranks, n_epochs=1): con_i = 0 for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], ranks[0], ranks[1]): + indices[0], indices[1], ranks[0], ranks[1] + ): self._log_connection_number(con_i) seed_idcs = seed_idcs.compressed() @@ -621,12 +679,13 
@@ def compute_con(self, indices, ranks, n_epochs=1): A_f, V = self._autocov_to_full_var(autocov) A_f_3d = np.reshape( - A_f, (n_times, n_signals, n_signals * self.n_lags), - order="F") + A_f, (n_times, n_signals, n_signals * self.n_lags), order="F" + ) A, K = self._full_var_to_iss(A_f_3d) self.con_scores[con_i] = self._iss_to_ugc( - A, A_f_3d, K, V, con_seeds, con_targets) + A, A_f_3d, K, V, con_seeds, con_targets + ) con_i += 1 @@ -660,16 +719,15 @@ def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): C_bb = csd[..., n_seeds:, n_seeds:] C_ba = csd[..., n_seeds:, :n_seeds] - C_bar_aa = np.matmul( - U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul( - U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul( - U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul( - U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa)) - C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), - np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) + C_bar_aa = np.matmul(U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa)) + C_bar_ab = np.matmul(U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb)) + C_bar_bb = np.matmul(U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb)) + C_bar_ba = np.matmul(U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa)) + C_bar = np.append( + np.append(C_bar_aa, C_bar_ab, axis=3), + np.append(C_bar_ba, C_bar_bb, axis=3), + axis=2, + ) return C_bar @@ -679,28 +737,33 @@ def _compute_autocov(self, csd): n_signals = csd.shape[2] circular_shifted_csd = np.concatenate( - [np.flip(np.conj(csd[:, 1:]), axis=1), csd[:, :-1]], axis=1) - ifft_shifted_csd = self._block_ifft( - circular_shifted_csd, self.freq_res) + [np.flip(np.conj(csd[:, 1:]), axis=1), csd[:, :-1]], axis=1 + ) + ifft_shifted_csd = self._block_ifft(circular_shifted_csd, self.freq_res) lags_ifft_shifted_csd = np.reshape( - ifft_shifted_csd[:, :self.n_lags + 1], - (n_times, self.n_lags + 1, n_signals ** 2), order="F") + 
ifft_shifted_csd[:, : self.n_lags + 1], + (n_times, self.n_lags + 1, n_signals**2), + order="F", + ) signs = np.repeat([1], self.n_lags + 1).tolist() signs[1::2] = [x * -1 for x in signs[1::2]] sign_matrix = np.repeat( - np.tile(np.array(signs), (n_signals ** 2, 1))[np.newaxis], - n_times, axis=0).transpose(0, 2, 1) + np.tile(np.array(signs), (n_signals**2, 1))[np.newaxis], n_times, axis=0 + ).transpose(0, 2, 1) - return np.real(np.reshape( - sign_matrix * lags_ifft_shifted_csd, - (n_times, self.n_lags + 1, n_signals, n_signals), order="F")) + return np.real( + np.reshape( + sign_matrix * lags_ifft_shifted_csd, + (n_times, self.n_lags + 1, n_signals, n_signals), + order="F", + ) + ) def _block_ifft(self, csd, n_points): """Compute block iFFT with n points.""" shape = csd.shape - csd_3d = np.reshape( - csd, (shape[0], shape[1], shape[2] * shape[3]), order="F") + csd_3d = np.reshape(csd, (shape[0], shape[1], shape[2] * shape[3]), order="F") csd_ifft = np.fft.ifft(csd_3d, n=n_points, axis=1) @@ -710,24 +773,28 @@ def _autocov_to_full_var(self, autocov): """Compute full VAR model using Whittle's LWR recursion.""" if np.any(np.linalg.det(autocov) == 0): raise RuntimeError( - 'the autocovariance matrix is singular; check if your data is ' - 'rank deficient and specify an appropriate rank argument <= ' - 'the rank of the seeds and targets') + "the autocovariance matrix is singular; check if your data is " + "rank deficient and specify an appropriate rank argument <= " + "the rank of the seeds and targets" + ) A_f, V = self._whittle_lwr_recursion(autocov) if not np.isfinite(A_f).all(): - raise RuntimeError('at least one VAR model coefficient is ' - 'infinite or NaN; check the data you are using') + raise RuntimeError( + "at least one VAR model coefficient is " + "infinite or NaN; check the data you are using" + ) try: np.linalg.cholesky(V) except np.linalg.LinAlgError as np_error: raise RuntimeError( - 'the covariance matrix of the residuals is not ' - 'positive-definite; 
check the singular values of your data ' - 'and specify an appropriate rank argument <= the rank of the ' - 'seeds and targets') from np_error + "the covariance matrix of the residuals is not " + "positive-definite; check the singular values of your data " + "and specify an appropriate rank argument <= the rank of the " + "seeds and targets" + ) from np_error return A_f, V @@ -744,11 +811,13 @@ def _whittle_lwr_recursion(self, G): cov = G[:, 0, :, :] # covariance G_f = np.reshape( - G[:, 1:, :, :].transpose(0, 3, 1, 2), (t, qn, n), - order="F") # forward autocov + G[:, 1:, :, :].transpose(0, 3, 1, 2), (t, qn, n), order="F" + ) # forward autocov G_b = np.reshape( - np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), - order="F").transpose(0, 2, 1) # backward autocov + np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), order="F" + ).transpose( + 0, 2, 1 + ) # backward autocov A_f = np.zeros((t, n, qn)) # forward coefficients A_b = np.zeros((t, n, qn)) # backward coefficients @@ -760,23 +829,29 @@ def _whittle_lwr_recursion(self, G): try: A_f[:, :, k_f] = np.linalg.solve( - cov, G_b[:, k_b, :].transpose(0, 2, 1)).transpose(0, 2, 1) + cov, G_b[:, k_b, :].transpose(0, 2, 1) + ).transpose(0, 2, 1) A_b[:, :, k_b] = np.linalg.solve( - cov, G_f[:, k_f, :].transpose(0, 2, 1)).transpose(0, 2, 1) + cov, G_f[:, k_f, :].transpose(0, 2, 1) + ).transpose(0, 2, 1) # Perform recursion for k in np.arange(2, q + 1): - var_A = (G_b[:, (r - 1) * n: r * n, :] - - np.matmul(A_f[:, :, k_f], G_b[:, k_b, :])) + var_A = G_b[:, (r - 1) * n : r * n, :] - np.matmul( + A_f[:, :, k_f], G_b[:, k_b, :] + ) var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :]) - AA_f = np.linalg.solve( - var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) + AA_f = np.linalg.solve(var_B, var_A.transpose(0, 2, 1)).transpose( + 0, 2, 1 + ) - var_A = (G_f[:, (k - 1) * n: k * n, :] - - np.matmul(A_b[:, :, k_b], G_f[:, k_f, :])) + var_A = G_f[:, (k - 1) * n : k * n, :] - np.matmul( + A_b[:, :, k_b], 
G_f[:, k_f, :] + ) var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :]) - AA_b = np.linalg.solve( - var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) + AA_b = np.linalg.solve(var_B, var_A.transpose(0, 2, 1)).transpose( + 0, 2, 1 + ) A_f_previous = A_f[:, :, k_f] A_b_previous = A_b[:, :, k_b] @@ -786,14 +861,17 @@ def _whittle_lwr_recursion(self, G): k_b = np.arange(r * n, qn) A_f[:, :, k_f] = np.dstack( - (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f)) + (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f) + ) A_b[:, :, k_b] = np.dstack( - (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous))) + (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous)) + ) except np.linalg.LinAlgError as np_error: raise RuntimeError( - 'the autocovariance matrix is singular; check if your data is ' - 'rank deficient and specify an appropriate rank argument <= ' - 'the rank of the seeds and targets') from np_error + "the autocovariance matrix is singular; check if your data is " + "rank deficient and specify an appropriate rank argument <= " + "the rank of the seeds and targets" + ) from np_error V = cov - np.matmul(A_f, G_f) A_f = np.reshape(A_f, (t, n, n, q), order="F") @@ -817,9 +895,12 @@ def _full_var_to_iss(self, A_f): I_p = np.dstack(t * [np.eye(m * p)]).transpose(2, 0, 1) A = np.hstack((A_f, I_p[:, : (m * p - m), :])) # state transition # matrix - K = np.hstack(( - np.dstack(t * [np.eye(m)]).transpose(2, 0, 1), - np.zeros((t, (m * (p - 1)), m)))) # Kalman gain matrix + K = np.hstack( + ( + np.dstack(t * [np.eye(m)]).transpose(2, 0, 1), + np.zeros((t, (m * (p - 1)), m)), + ) + ) # Kalman gain matrix return A, K @@ -843,8 +924,7 @@ def _iss_to_ugc(self, A, C, K, V, seeds, targets): HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2)) # Eq. 
11 - return np.real( - np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH))) + return np.real(np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH))) def _iss_to_tf(self, A, C, K, z): """Compute transfer function for innovations-form state-space params. @@ -873,12 +953,11 @@ def _iss_to_tf(self, A, C, K, z): _gc_compute_H, self.n_jobs, verbose=False ) H = np.zeros((h, t, n, n), dtype=np.complex128) - for block_i in ProgressBar( - range(self.n_steps), mesg="frequency blocks" - ): + for block_i in ProgressBar(range(self.n_steps), mesg="frequency blocks"): freqs = self._get_block_indices(block_i, self.n_freqs) H[freqs] = parallel( - parallel_compute_H(A, C, K, z[k], I_n, I_m) for k in freqs) + parallel_compute_H(A, C, K, z[k], I_n, I_m) for k in freqs + ) return H @@ -914,10 +993,12 @@ def _gc_compute_H(A, C, K, z_k, I_n, I_m): 10.1103/PhysRevE.91.040101, Eq. 4. """ from scipy import linalg # XXX: is this necessary??? + H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) for t in range(A.shape[0]): H[t] = I_n + np.matmul( - C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t])) + C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t]) + ) return H @@ -936,9 +1017,12 @@ class _GCTREst(_GCEstBase): # map names to estimator types _CON_METHOD_MAP_MULTIVARIATE = { - 'cacoh': _CaCohEst, 'mic': _MICEst, 'mim': _MIMEst, 'gc': _GCEst, - 'gc_tr': _GCTREst + "cacoh": _CaCohEst, + "mic": _MICEst, + "mim": _MIMEst, + "gc": _GCEst, + "gc_tr": _GCTREst, } -_multivariate_methods = ['cacoh', 'mic', 'mim', 'gc', 'gc_tr'] -_gc_methods = ['gc', 'gc_tr'] +_multivariate_methods = ["cacoh", "mic", "mim", "gc", "gc_tr"] +_gc_methods = ["gc", "gc_tr"] diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 88545e0b..4b2a398a 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1,23 +1,28 @@ import os import 
numpy as np -from numpy.testing import (assert_allclose, assert_array_almost_equal, - assert_array_less) +from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_less import pandas as pd import pytest -from mne import (EpochsArray, SourceEstimate, create_info) +from mne import EpochsArray, SourceEstimate, create_info from mne.filter import filter_data from mne_connectivity import ( - SpectralConnectivity, spectral_connectivity_epochs, - read_connectivity, spectral_connectivity_time) -from mne_connectivity.spectral.epochs import (_get_n_epochs, - _compute_freq_mask, - _compute_freqs) + SpectralConnectivity, + spectral_connectivity_epochs, + read_connectivity, + spectral_connectivity_time, +) +from mne_connectivity.spectral.epochs import ( + _get_n_epochs, + _compute_freq_mask, + _compute_freqs, +) from mne_connectivity.spectral.epochs_bivariate import _CohEst -def create_test_dataset(sfreq, n_signals, n_epochs, n_times, tmin, tmax, - fstart, fend, trans_bandwidth=2.): +def create_test_dataset( + sfreq, n_signals, n_epochs, n_times, tmin, tmax, fstart, fend, trans_bandwidth=2.0 +): """Create test dataset with no spurious correlations. 
Parameters @@ -58,10 +63,16 @@ def create_test_dataset(sfreq, n_signals, n_epochs, n_times, tmin, tmax, times_data = np.linspace(tmin, tmax, n_times) # simulate connectivity from fstart to fend - data[1, :] = filter_data(data[0, :], sfreq, fstart, fend, - filter_length='auto', fir_design='firwin2', - l_trans_bandwidth=trans_bandwidth, - h_trans_bandwidth=trans_bandwidth) + data[1, :] = filter_data( + data[0, :], + sfreq, + fstart, + fend, + filter_length="auto", + fir_design="firwin2", + l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth, + ) # add some noise, so the spectrum is not exactly zero data[1, :] += 1e-2 * rng.randn(n_times * n_epochs) data = data.reshape(n_signals, n_epochs, n_times) @@ -74,27 +85,29 @@ def _stc_gen(data, sfreq, tmin, combo=False): vertices = [np.arange(data.shape[1]), np.empty(0)] for d in data: if not combo: - stc = SourceEstimate(data=d, vertices=vertices, - tmin=tmin, tstep=1 / float(sfreq)) + stc = SourceEstimate( + data=d, vertices=vertices, tmin=tmin, tstep=1 / float(sfreq) + ) yield stc else: # simulate a combination of array and source estimate arr = d[0] - stc = SourceEstimate(data=d[1:], vertices=vertices, - tmin=tmin, tstep=1 / float(sfreq)) + stc = SourceEstimate( + data=d[1:], vertices=vertices, tmin=tmin, tstep=1 / float(sfreq) + ) yield (arr, stc) -@pytest.mark.parametrize('method', ['coh', 'cohy', 'imcoh', 'plv']) -@pytest.mark.parametrize('mode', ['multitaper', 'fourier', 'cwt_morlet']) +@pytest.mark.parametrize("method", ["coh", "cohy", "imcoh", "plv"]) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) def test_spectral_connectivity_parallel(method, mode, tmp_path): """Test saving spectral connectivity with parallel functions.""" # Use a case known to have no spurious correlations (it would bad if # tests could randomly fail): rng = np.random.RandomState(0) - trans_bandwidth = 2. + trans_bandwidth = 2.0 - sfreq = 50. 
+ sfreq = 50.0 n_signals = 3 n_epochs = 8 n_times = 256 @@ -103,10 +116,16 @@ def test_spectral_connectivity_parallel(method, mode, tmp_path): data = rng.randn(n_signals, n_epochs * n_times) # simulate connectivity from 5Hz..15Hz fstart, fend = 5.0, 15.0 - data[1, :] = filter_data(data[0, :], sfreq, fstart, fend, - filter_length='auto', fir_design='firwin2', - l_trans_bandwidth=trans_bandwidth, - h_trans_bandwidth=trans_bandwidth) + data[1, :] = filter_data( + data[0, :], + sfreq, + fstart, + fend, + filter_length="auto", + fir_design="firwin2", + l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth, + ) # add some noise, so the spectrum is not exactly zero data[1, :] += 1e-2 * rng.randn(n_times * n_epochs) data = data.reshape(n_signals, n_epochs, n_times) @@ -115,122 +134,162 @@ def test_spectral_connectivity_parallel(method, mode, tmp_path): # define some frequencies for cwt cwt_freqs = np.arange(3, 24.5, 1) - if method == 'coh' and mode == 'multitaper': + if method == "coh" and mode == "multitaper": # only check adaptive estimation for coh to reduce test time check_adaptive = [False, True] else: check_adaptive = [False] - if method == 'coh' and mode == 'cwt_morlet': + if method == "coh" and mode == "cwt_morlet": # so we also test using an array for num cycles - cwt_n_cycles = 7. * np.ones(len(cwt_freqs)) + cwt_n_cycles = 7.0 * np.ones(len(cwt_freqs)) else: - cwt_n_cycles = 7. + cwt_n_cycles = 7.0 for adaptive in check_adaptive: - if adaptive: - mt_bandwidth = 1. 
+ mt_bandwidth = 1.0 else: mt_bandwidth = None con = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=None, sfreq=sfreq, - mt_adaptive=adaptive, mt_low_bias=True, - mt_bandwidth=mt_bandwidth, cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, n_jobs=n_jobs) - - tmp_file = tmp_path / 'temp_file.nc' + data, + method=method, + mode=mode, + indices=None, + sfreq=sfreq, + mt_adaptive=adaptive, + mt_low_bias=True, + mt_bandwidth=mt_bandwidth, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + n_jobs=n_jobs, + ) + + tmp_file = tmp_path / "temp_file.nc" con.save(tmp_file) read_con = read_connectivity(tmp_file) assert_array_almost_equal(con.get_data(), read_con.get_data()) # split `repr` before the file size (`~23 kB` for example) - a = repr(con).split('~')[0] - b = repr(read_con).split('~')[0] + a = repr(con).split("~")[0] + b = repr(read_con).split("~")[0] assert a == b -@pytest.mark.parametrize('method', ['coh', 'cohy', 'imcoh', 'plv', - ['ciplv', 'ppc', 'pli', 'pli2_unbiased', - 'dpli', 'wpli', 'wpli2_debiased', 'coh']]) -@pytest.mark.parametrize('mode', ['multitaper', 'fourier', 'cwt_morlet']) +@pytest.mark.parametrize( + "method", + [ + "coh", + "cohy", + "imcoh", + "plv", + [ + "ciplv", + "ppc", + "pli", + "pli2_unbiased", + "dpli", + "wpli", + "wpli2_debiased", + "coh", + ], + ], +) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) def test_spectral_connectivity(method, mode): """Test frequency-domain connectivity methods.""" - sfreq = 50. + sfreq = 50.0 n_signals = 3 n_epochs = 8 n_times = 256 - trans_bandwidth = 2. - tmin = 0. 
+ trans_bandwidth = 2.0 + tmin = 0.0 tmax = (n_times - 1) / sfreq # 5Hz..15Hz fstart, fend = 5.0, 15.0 data, times_data = create_test_dataset( - sfreq, n_signals=n_signals, n_epochs=n_epochs, n_times=n_times, - tmin=tmin, tmax=tmax, - fstart=fstart, fend=fend, trans_bandwidth=trans_bandwidth) + sfreq, + n_signals=n_signals, + n_epochs=n_epochs, + n_times=n_times, + tmin=tmin, + tmax=tmax, + fstart=fstart, + fend=fend, + trans_bandwidth=trans_bandwidth, + ) # First we test some invalid parameters: - pytest.raises(ValueError, spectral_connectivity_epochs, - data, method='notamethod') - pytest.raises(ValueError, spectral_connectivity_epochs, data, - mode='notamode') + pytest.raises(ValueError, spectral_connectivity_epochs, data, method="notamethod") + pytest.raises(ValueError, spectral_connectivity_epochs, data, mode="notamode") # test invalid fmin fmax settings - pytest.raises(ValueError, spectral_connectivity_epochs, data, fmin=10, - fmax=10 + 0.5 * (sfreq / float(n_times))) - pytest.raises(ValueError, spectral_connectivity_epochs, - data, fmin=10, fmax=5) - pytest.raises(ValueError, spectral_connectivity_epochs, data, fmin=(0, 11), - fmax=(5, 10)) - pytest.raises(ValueError, spectral_connectivity_epochs, data, fmin=(11,), - fmax=(12, 15)) + pytest.raises( + ValueError, + spectral_connectivity_epochs, + data, + fmin=10, + fmax=10 + 0.5 * (sfreq / float(n_times)), + ) + pytest.raises(ValueError, spectral_connectivity_epochs, data, fmin=10, fmax=5) + pytest.raises( + ValueError, spectral_connectivity_epochs, data, fmin=(0, 11), fmax=(5, 10) + ) + pytest.raises( + ValueError, spectral_connectivity_epochs, data, fmin=(11,), fmax=(12, 15) + ) # define some frequencies for cwt cwt_freqs = np.arange(3, 24.5, 1) - if method == 'coh' and mode == 'multitaper': + if method == "coh" and mode == "multitaper": # only check adaptive estimation for coh to reduce test time check_adaptive = [False, True] else: check_adaptive = [False] - if method == 'coh' and mode == 'cwt_morlet': + 
if method == "coh" and mode == "cwt_morlet": # so we also test using an array for num cycles - cwt_n_cycles = 7. * np.ones(len(cwt_freqs)) + cwt_n_cycles = 7.0 * np.ones(len(cwt_freqs)) else: - cwt_n_cycles = 7. + cwt_n_cycles = 7.0 for adaptive in check_adaptive: - if adaptive: - mt_bandwidth = 1. + mt_bandwidth = 1.0 else: mt_bandwidth = None con = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=None, sfreq=sfreq, - mt_adaptive=adaptive, mt_low_bias=True, - mt_bandwidth=mt_bandwidth, cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles) + data, + method=method, + mode=mode, + indices=None, + sfreq=sfreq, + mt_adaptive=adaptive, + mt_low_bias=True, + mt_bandwidth=mt_bandwidth, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + ) if isinstance(method, list): this_con = con[0] else: this_con = con - freqs = this_con.attrs.get('freqs_used') + freqs = this_con.attrs.get("freqs_used") n = this_con.n_epochs_used if isinstance(this_con, SpectralConnectivity): - times = this_con.attrs.get('times_used') + times = this_con.attrs.get("times_used") else: times = this_con.times - assert (n == n_epochs) + assert n == n_epochs assert_array_almost_equal(times_data, times) - if mode == 'multitaper': + if mode == "multitaper": upper_t = 0.95 lower_t = 0.5 else: # mode == 'fourier' or mode == 'cwt_morlet' @@ -240,54 +299,42 @@ def test_spectral_connectivity(method, mode): # test the simulated signal gidx = np.searchsorted(freqs, (fstart, fend)) - bidx = np.searchsorted(freqs, - (fstart - trans_bandwidth * 2, - fend + trans_bandwidth * 2)) - if method == 'coh': + bidx = np.searchsorted( + freqs, (fstart - trans_bandwidth * 2, fend + trans_bandwidth * 2) + ) + if method == "coh": assert np.all( - con.get_data(output='dense')[ - 1, 0, gidx[0]:gidx[1] - ] > upper_t), \ - con.get_data()[ - 1, 0, gidx[0]:gidx[1]].min() + con.get_data(output="dense")[1, 0, gidx[0] : gidx[1]] > upper_t + ), con.get_data()[1, 0, gidx[0] : gidx[1]].min() # we see something for zero-lag 
- assert_array_less( - con.get_data(output='dense') - [1, 0, :bidx[0]], - lower_t) + assert_array_less(con.get_data(output="dense")[1, 0, : bidx[0]], lower_t) assert np.all( - con.get_data(output='dense')[1, 0, bidx[1]:] < lower_t), \ - con.get_data()[1, 0, bidx[1:]].max() - elif method == 'cohy': + con.get_data(output="dense")[1, 0, bidx[1] :] < lower_t + ), con.get_data()[1, 0, bidx[1:]].max() + elif method == "cohy": # imaginary coh will be zero - check = np.imag(con.get_data(output='dense') - [1, 0, gidx[0]:gidx[1]]) + check = np.imag(con.get_data(output="dense")[1, 0, gidx[0] : gidx[1]]) assert np.all(check < lower_t), check.max() # we see something for zero-lag assert_array_less( - upper_t, - np.abs(con.get_data(output='dense')[ - 1, 0, gidx[0]:gidx[1] - ])) + upper_t, np.abs(con.get_data(output="dense")[1, 0, gidx[0] : gidx[1]]) + ) assert_array_less( - np.abs(con.get_data(output='dense')[1, 0, :bidx[0]]), - lower_t) + np.abs(con.get_data(output="dense")[1, 0, : bidx[0]]), lower_t + ) assert_array_less( - np.abs(con.get_data(output='dense')[1, 0, bidx[1]:]), - lower_t) - elif method == 'imcoh': + np.abs(con.get_data(output="dense")[1, 0, bidx[1] :]), lower_t + ) + elif method == "imcoh": # imaginary coh will be zero assert_array_less( - con.get_data(output='dense')[1, 0, gidx[0]:gidx[1]], - lower_t) - assert_array_less( - con.get_data(output='dense')[1, 0, :bidx[0]], - lower_t) - assert_array_less( - con.get_data(output='dense')[1, 0, bidx[1]:], lower_t), + con.get_data(output="dense")[1, 0, gidx[0] : gidx[1]], lower_t + ) + assert_array_less(con.get_data(output="dense")[1, 0, : bidx[0]], lower_t) + assert_array_less(con.get_data(output="dense")[1, 0, bidx[1] :], lower_t), assert np.all( - con.get_data(output='dense')[1, 0, bidx[1]:] < lower_t), \ - con.get_data()[1, 0, bidx[1]:].max() + con.get_data(output="dense")[1, 0, bidx[1] :] < lower_t + ), con.get_data()[1, 0, bidx[1] :].max() # compute a subset of connections using indices and 2 jobs indices = 
(np.array([2, 1]), np.array([0, 0])) @@ -299,21 +346,30 @@ def test_spectral_connectivity(method, mode): stc_data = _stc_gen(data, sfreq, tmin) con2 = spectral_connectivity_epochs( - stc_data, method=test_methods, mode=mode, indices=indices, - sfreq=sfreq, mt_adaptive=adaptive, mt_low_bias=True, - mt_bandwidth=mt_bandwidth, tmin=tmin, tmax=tmax, - cwt_freqs=cwt_freqs, cwt_n_cycles=cwt_n_cycles) + stc_data, + method=test_methods, + mode=mode, + indices=indices, + sfreq=sfreq, + mt_adaptive=adaptive, + mt_low_bias=True, + mt_bandwidth=mt_bandwidth, + tmin=tmin, + tmax=tmax, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + ) assert isinstance(con2, list) assert len(con2) == len(test_methods) - freqs2 = con2[0].attrs.get('freqs_used') - if 'times' in con2[0].dims: + freqs2 = con2[0].attrs.get("freqs_used") + if "times" in con2[0].dims: times2 = con2[0].times else: - times2 = con2[0].attrs.get('times_used') + times2 = con2[0].attrs.get("times_used") n2 = con2[0].n_epochs_used - if method == 'coh': + if method == "coh": assert_array_almost_equal(con2[0].get_data(), con2[1].get_data()) if not isinstance(method, list): @@ -324,38 +380,50 @@ def test_spectral_connectivity(method, mode): # "con2" is a raveled array already, so # simulate setting indices on the full output in "con" - assert_array_almost_equal(con.get_data(output='dense')[indices], - con2.get_data()) - assert (n == n2) + assert_array_almost_equal( + con.get_data(output="dense")[indices], con2.get_data() + ) + assert n == n2 assert_array_almost_equal(times_data, times2) else: # we get the same result for the probed connections - assert (len(con) == len(con2)) + assert len(con) == len(con2) for c, c2 in zip(con, con2): assert_array_almost_equal(freqs, freqs2) - assert_array_almost_equal(c.get_data(output='dense')[indices], - c2.get_data()) - assert (n == n2) + assert_array_almost_equal( + c.get_data(output="dense")[indices], c2.get_data() + ) + assert n == n2 assert_array_almost_equal(times_data, times2) # 
Test with faverage # compute same connections for two bands, fskip=1, and f. avg. - fmin = (5., 15.) - fmax = (15., 30.) + fmin = (5.0, 15.0) + fmax = (15.0, 30.0) con3 = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, - sfreq=sfreq, fmin=fmin, fmax=fmax, fskip=1, faverage=True, - mt_adaptive=adaptive, mt_low_bias=True, - mt_bandwidth=mt_bandwidth, cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles) + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + fskip=1, + faverage=True, + mt_adaptive=adaptive, + mt_low_bias=True, + mt_bandwidth=mt_bandwidth, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + ) if isinstance(method, list): - freqs3 = con3[0].attrs.get('freqs_used') + freqs3 = con3[0].attrs.get("freqs_used") else: - freqs3 = con3.attrs.get('freqs_used') + freqs3 = con3.attrs.get("freqs_used") - assert (isinstance(freqs3, list)) - assert (len(freqs3) == len(fmin)) + assert isinstance(freqs3, list) + assert len(freqs3) == len(fmin) for i in range(len(freqs3)): _fmin = max(fmin[i], min(cwt_freqs)) _fmax = min(fmax[i], max(cwt_freqs)) @@ -368,15 +436,14 @@ def test_spectral_connectivity(method, mode): for i in range(len(freqs3)): # now we want to get the frequency indices # create a frequency mask for all bands - n_times = len(con2.attrs.get('times_used')) + n_times = len(con2.attrs.get("times_used")) # compute frequencies to analyze based on number of samples, # sampling rate, specified wavelet frequencies and mode freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode) # compute the mask based on specified min/max and decim factor - freq_mask = _compute_freq_mask( - freqs, [fmin[i]], [fmax[i]], fskip) + freq_mask = _compute_freq_mask(freqs, [fmin[i]], [fmax[i]], fskip) freqs = freqs[freq_mask] freqs_idx = np.searchsorted(freqs2, freqs) con2_avg = np.mean(con2.get_data()[:, freqs_idx], axis=1) @@ -386,7 +453,7 @@ def test_spectral_connectivity(method, mode): for i in 
range(len(freqs3)): # now we want to get the frequency indices # create a frequency mask for all bands - n_times = len(con2[0].attrs.get('times_used')) + n_times = len(con2[0].attrs.get("times_used")) # compute frequencies to analyze based on number of # samples, sampling rate, specified wavelet frequencies @@ -395,29 +462,26 @@ def test_spectral_connectivity(method, mode): # compute the mask based on specified min/max and # decim factor - freq_mask = _compute_freq_mask( - freqs, [fmin[i]], [fmax[i]], fskip) + freq_mask = _compute_freq_mask(freqs, [fmin[i]], [fmax[i]], fskip) freqs = freqs[freq_mask] freqs_idx = np.searchsorted(freqs2, freqs) - con2_avg = np.mean(con2[j].get_data()[ - :, freqs_idx], axis=1) - assert_array_almost_equal( - con2_avg, con3[j].get_data()[:, i]) + con2_avg = np.mean(con2[j].get_data()[:, freqs_idx], axis=1) + assert_array_almost_equal(con2_avg, con3[j].get_data()[:, i]) # test _get_n_epochs full_list = list(range(10)) out_lens = np.array([len(x) for x in _get_n_epochs(full_list, 4)]) - assert ((out_lens == np.array([4, 4, 2])).all()) + assert (out_lens == np.array([4, 4, 2])).all() out_lens = np.array([len(x) for x in _get_n_epochs(full_list, 11)]) - assert (len(out_lens) > 0) - assert (out_lens[0] == 10) + assert len(out_lens) > 0 + assert out_lens[0] == 10 -@pytest.mark.parametrize('method', ['cacoh', 'mic', 'mim', 'gc']) +@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", "gc"]) def test_spectral_connectivity_epochs_multivariate(method): """Test over-epoch multivariate connectivity methods.""" - mode = 'multitaper' # stick with single mode in interest of time + mode = "multitaper" # stick with single mode in interest of time sfreq = 100.0 # Hz n_signals = 4 # should be even! 
@@ -427,8 +491,10 @@ def test_spectral_connectivity_epochs_multivariate(method): trans_bandwidth = 2.0 # Hz delay = 10 # samples (non-zero delay needed for ImCoh and GC to be >> 0) - indices = (np.arange(n_seeds)[np.newaxis, :], - np.arange(n_seeds)[np.newaxis, :] + n_seeds) + indices = ( + np.arange(n_seeds)[np.newaxis, :], + np.arange(n_seeds)[np.newaxis, :] + n_seeds, + ) n_targets = n_seeds # 15-25 Hz connectivity @@ -437,127 +503,172 @@ def test_spectral_connectivity_epochs_multivariate(method): data = rng.randn(n_signals, n_epochs * n_times + delay) # simulate connectivity from fstart to fend data[n_seeds:, :] = filter_data( - data[:n_seeds, :], sfreq, fstart, fend, filter_length='auto', - fir_design='firwin2', l_trans_bandwidth=trans_bandwidth, - h_trans_bandwidth=trans_bandwidth) + data[:n_seeds, :], + sfreq, + fstart, + fend, + filter_length="auto", + fir_design="firwin2", + l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth, + ) # add some noise, so the spectrum is not exactly zero data[n_seeds:, :] += 1e-2 * rng.randn(n_seeds, n_times * n_epochs + delay) # shift the seeds to that the targets are a delayed version of them - data[:n_seeds, :n_epochs * n_times] = data[:n_seeds, delay:] - data = data[:, :n_times * n_epochs] + data[:n_seeds, : n_epochs * n_times] = data[:n_seeds, delay:] + data = data[:, : n_times * n_epochs] data = data.reshape(n_signals, n_epochs, n_times) data = np.transpose(data, [1, 0, 2]) con = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, sfreq=sfreq, - gc_n_lags=20) + data, method=method, mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=20 + ) freqs = con.freqs gidx = (freqs.index(fstart), freqs.index(fend) + 1) - bidx = (freqs.index(fstart - trans_bandwidth * 2), - freqs.index(fend + trans_bandwidth * 2) + 1) + bidx = ( + freqs.index(fstart - trans_bandwidth * 2), + freqs.index(fend + trans_bandwidth * 2) + 1, + ) - if method in ['cacoh', 'mic', 'mim']: + if method in ["cacoh", 
"mic", "mim"]: lower_t = 0.2 upper_t = 0.5 - assert np.abs(con.get_data())[0, gidx[0]:gidx[1]].mean() > upper_t - assert np.abs(con.get_data())[0, :bidx[0]].mean() < lower_t - assert np.abs(con.get_data())[0, bidx[1]:].mean() < lower_t + assert np.abs(con.get_data())[0, gidx[0] : gidx[1]].mean() > upper_t + assert np.abs(con.get_data())[0, : bidx[0]].mean() < lower_t + assert np.abs(con.get_data())[0, bidx[1] :].mean() < lower_t - elif method == 'gc': + elif method == "gc": lower_t = 0.2 upper_t = 0.8 - assert con.get_data()[0, gidx[0]:gidx[1]].mean() > upper_t - assert con.get_data()[0, :bidx[0]].mean() < lower_t - assert con.get_data()[0, bidx[1]:].mean() < lower_t + assert con.get_data()[0, gidx[0] : gidx[1]].mean() > upper_t + assert con.get_data()[0, : bidx[0]].mean() < lower_t + assert con.get_data()[0, bidx[1] :].mean() < lower_t # check that target -> seed connectivity is low indices_ts = (indices[1], indices[0]) con_ts = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices_ts, sfreq=sfreq, - gc_n_lags=20) - assert con_ts.get_data()[0, gidx[0]:gidx[1]].mean() < lower_t + data, + method=method, + mode=mode, + indices=indices_ts, + sfreq=sfreq, + gc_n_lags=20, + ) + assert con_ts.get_data()[0, gidx[0] : gidx[1]].mean() < lower_t # check that TRGC is positive (i.e. 
net seed -> target connectivity not # due to noise) con_tr = spectral_connectivity_epochs( - data, method='gc_tr', mode=mode, indices=indices, sfreq=sfreq, - gc_n_lags=20) + data, method="gc_tr", mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=20 + ) con_ts_tr = spectral_connectivity_epochs( - data, method='gc_tr', mode=mode, indices=indices_ts, sfreq=sfreq, - gc_n_lags=20) - trgc = ((con.get_data() - con_ts.get_data()) - - (con_tr.get_data() - con_ts_tr.get_data())) + data, + method="gc_tr", + mode=mode, + indices=indices_ts, + sfreq=sfreq, + gc_n_lags=20, + ) + trgc = (con.get_data() - con_ts.get_data()) - ( + con_tr.get_data() - con_ts_tr.get_data() + ) # checks that TRGC is positive and >> 0 (for 15-25 Hz) - assert np.all(trgc[0, gidx[0]:gidx[1]] > 0) - assert np.all(trgc[0, gidx[0]:gidx[1]] > upper_t) + assert np.all(trgc[0, gidx[0] : gidx[1]] > 0) + assert np.all(trgc[0, gidx[0] : gidx[1]] > upper_t) # checks that TRGC is ~ 0 for other frequencies - assert np.allclose(trgc[0, :bidx[0]].mean(), 0, atol=lower_t) - assert np.allclose(trgc[0, bidx[1]:].mean(), 0, atol=lower_t) + assert np.allclose(trgc[0, : bidx[0]].mean(), 0, atol=lower_t) + assert np.allclose(trgc[0, bidx[1] :].mean(), 0, atol=lower_t) # check all-to-all conn. 
computed for MIC/MIM when no indices given - if method in ['cacoh', 'mic', 'mim']: + if method in ["cacoh", "mic", "mim"]: con = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=None, sfreq=sfreq) + data, method=method, mode=mode, indices=None, sfreq=sfreq + ) assert con.indices is None assert con.n_nodes == n_signals - if method in ['cacoh', 'mic']: - assert np.array(con.attrs['patterns']).shape[2] == n_signals + if method in ["cacoh", "mic"]: + assert np.array(con.attrs["patterns"]).shape[2] == n_signals # check ragged indices padded correctly ragged_indices = (np.array([[0]]), np.array([[1, 2]])) con = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq) - assert np.all(np.array(con.indices) == - np.array([np.array([[0, -1]]), np.array([[1, 2]])])) + data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq + ) + assert np.all( + np.array(con.indices) == np.array([np.array([[0, -1]]), np.array([[1, 2]])]) + ) # check shape of MIC patterns - if method in ['cacoh', 'mic']: - for mode in ['multitaper', 'cwt_morlet']: + if method in ["cacoh", "mic"]: + for mode in ["multitaper", "cwt_morlet"]: con = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, sfreq=sfreq, - fmin=10, fmax=25, cwt_freqs=np.arange(10, 25), - faverage=True) + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + fmin=10, + fmax=25, + cwt_freqs=np.arange(10, 25), + faverage=True, + ) - if mode == 'cwt_morlet': + if mode == "cwt_morlet": patterns_shape = ( (n_seeds, len(con.freqs), len(con.times)), - (n_targets, len(con.freqs), len(con.times))) + (n_targets, len(con.freqs), len(con.times)), + ) else: patterns_shape = ( (n_seeds, len(con.freqs)), - (n_targets, len(con.freqs))) + (n_targets, len(con.freqs)), + ) assert np.shape(con.attrs["patterns"][0][0]) == patterns_shape[0] assert np.shape(con.attrs["patterns"][1][0]) == patterns_shape[1] # only check these once for 
speed - if mode == 'multitaper': + if mode == "multitaper": # check patterns averaged over freqs - fmin = (5., 15.) - fmax = (15., 30.) + fmin = (5.0, 15.0) + fmax = (15.0, 30.0) con = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, - sfreq=sfreq, fmin=fmin, fmax=fmax, faverage=True) + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=True, + ) assert np.shape(con.attrs["patterns"][0][0])[1] == len(fmin) assert np.shape(con.attrs["patterns"][1][0])[1] == len(fmin) # check patterns shape matches input data, not rank rank = (np.array([1]), np.array([1])) con = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=rank) - assert (np.shape(con.attrs["patterns"][0][0])[0] == n_seeds) - assert (np.shape(con.attrs["patterns"][1][0])[0] == n_targets) + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + rank=rank, + ) + assert np.shape(con.attrs["patterns"][0][0])[0] == n_seeds + assert np.shape(con.attrs["patterns"][1][0])[0] == n_targets # check patterns padded correctly ragged_indices = (np.array([[0]]), np.array([[1, 2]])) con = spectral_connectivity_epochs( - data, method=method, mode=mode, indices=ragged_indices, - sfreq=sfreq) + data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq + ) patterns = np.array(con.attrs["patterns"]) patterns_shape = ( - (n_seeds, len(con.freqs)), (n_targets, len(con.freqs))) + (n_seeds, len(con.freqs)), + (n_targets, len(con.freqs)), + ) assert patterns[0, 0].shape == patterns_shape[0] assert patterns[1, 0].shape == patterns_shape[1] assert not np.any(np.isnan(patterns[0, 0, 0])) @@ -582,34 +693,46 @@ def test_multivariate_spectral_connectivity_epochs_regression(): computing the CSD or the final connectivity scores! 
""" fpath = os.path.dirname(os.path.realpath(__file__)) - data = pd.read_pickle( - os.path.join(fpath, 'data', 'example_multivariate_data.pkl')) + data = pd.read_pickle(os.path.join(fpath, "data", "example_multivariate_data.pkl")) sfreq = 100 indices = (np.array([[0, 1]]), np.array([[2, 3]])) - methods = ['cacoh', 'mic', 'mim', 'gc', 'gc_tr'] + methods = ["cacoh", "mic", "mim", "gc", "gc_tr"] con = spectral_connectivity_epochs( - data, method=methods, indices=indices, mode='multitaper', sfreq=sfreq, - fskip=0, faverage=False, tmin=0, tmax=None, mt_bandwidth=4, - mt_low_bias=True, mt_adaptive=False, gc_n_lags=20, - rank=tuple([[2], [2]]), n_jobs=1) + data, + method=methods, + indices=indices, + mode="multitaper", + sfreq=sfreq, + fskip=0, + faverage=False, + tmin=0, + tmax=None, + mt_bandwidth=4, + mt_low_bias=True, + mt_adaptive=False, + gc_n_lags=20, + rank=tuple([[2], [2]]), + n_jobs=1, + ) # should take the absolute of the MIC scores, as the MATLAB implementation # returns the absolute values - mne_results = {this_con.method: np.abs(this_con.get_data()) - for this_con in con} + mne_results = {this_con.method: np.abs(this_con.get_data()) for this_con in con} matlab_results = pd.read_pickle( - os.path.join(fpath, 'data', 'example_multivariate_matlab_results.pkl')) + os.path.join(fpath, "data", "example_multivariate_matlab_results.pkl") + ) for method in methods: assert_allclose(matlab_results[method], mne_results[method], 1e-5) @pytest.mark.parametrize( - 'method', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr', - ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']]) -@pytest.mark.parametrize('mode', ['multitaper', 'fourier', 'cwt_morlet']) + "method", + ["cacoh", "mic", "mim", "gc", "gc_tr", ["cacoh", "mic", "mim", "gc", "gc_tr"]], +) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): """Test error catching for multivar. freq.-domain connectivity methods.""" - sfreq = 50. 
+ sfreq = 50.0 n_signals = 4 # Do not change! n_epochs = 8 n_times = 256 @@ -619,55 +742,93 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): cwt_freqs = np.arange(10, 25 + 1) # check bad indices without nested array caught - with pytest.raises(TypeError, - match='multivariate indices must contain array-likes'): + with pytest.raises( + TypeError, match="multivariate indices must contain array-likes" + ): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) spectral_connectivity_epochs( - data, method=method, mode=mode, indices=non_nested_indices, - sfreq=sfreq, gc_n_lags=10) + data, + method=method, + mode=mode, + indices=non_nested_indices, + sfreq=sfreq, + gc_n_lags=10, + ) # check bad indices with repeated channels caught - with pytest.raises(ValueError, - match='multivariate indices cannot contain repeated'): + with pytest.raises( + ValueError, match="multivariate indices cannot contain repeated" + ): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) spectral_connectivity_epochs( - data, method=method, mode=mode, indices=repeated_indices, - sfreq=sfreq, gc_n_lags=10) + data, + method=method, + mode=mode, + indices=repeated_indices, + sfreq=sfreq, + gc_n_lags=10, + ) # check mixed methods caught - with pytest.raises(ValueError, - match='bivariate and multivariate connectivity'): + with pytest.raises(ValueError, match="bivariate and multivariate connectivity"): if isinstance(method, str): - mixed_methods = [method, 'coh'] + mixed_methods = [method, "coh"] elif isinstance(method, list): - mixed_methods = [*method, 'coh'] - spectral_connectivity_epochs(data, method=mixed_methods, mode=mode, - indices=indices, sfreq=sfreq, - cwt_freqs=cwt_freqs) + mixed_methods = [*method, "coh"] + spectral_connectivity_epochs( + data, + method=mixed_methods, + mode=mode, + indices=indices, + sfreq=sfreq, + cwt_freqs=cwt_freqs, + ) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) - with 
pytest.raises(ValueError, - match='ranks for seeds and targets must be'): + with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_low_rank, cwt_freqs=cwt_freqs) + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + rank=too_low_rank, + cwt_freqs=cwt_freqs, + ) too_high_rank = (np.array([3]), np.array([3])) - with pytest.raises(ValueError, - match='ranks for seeds and targets must be'): + with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_high_rank, cwt_freqs=cwt_freqs) + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + rank=too_high_rank, + cwt_freqs=cwt_freqs, + ) too_few_rank = ([], []) - with pytest.raises(ValueError, match='rank argument must have shape'): + with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_few_rank, cwt_freqs=cwt_freqs) + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + rank=too_few_rank, + cwt_freqs=cwt_freqs, + ) too_much_rank = (np.array([2, 2]), np.array([2, 2])) - with pytest.raises(ValueError, match='rank argument must have shape'): + with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_much_rank, cwt_freqs=cwt_freqs) + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + rank=too_much_rank, + cwt_freqs=cwt_freqs, + ) # check rank-deficient data caught bad_data = data.copy() @@ -677,65 +838,102 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): assert np.all(np.linalg.matrix_rank(bad_data[:, (2, 3), :]) == 1) if isinstance(method, 
str): rank_con = spectral_connectivity_epochs( - bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, - gc_n_lags=10, cwt_freqs=cwt_freqs) + bad_data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + gc_n_lags=10, + cwt_freqs=cwt_freqs, + ) assert rank_con.attrs["rank"] == (np.array([1]), np.array([1])) - if method in ['cacoh', 'mic', 'mim']: + if method in ["cacoh", "mic", "mim"]: # check rank-deficient transformation matrix caught - with pytest.raises(RuntimeError, - match='the transformation matrix'): + with pytest.raises(RuntimeError, match="the transformation matrix"): spectral_connectivity_epochs( - bad_data, method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=(np.array([2]), np.array([2])), - cwt_freqs=cwt_freqs) + bad_data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + rank=(np.array([2]), np.array([2])), + cwt_freqs=cwt_freqs, + ) # only check these once (e.g. only with multitaper) for speed - if method == 'gc' and mode == 'multitaper': + if method == "gc" and mode == "multitaper": # check bad n_lags caught frange = (5, 10) n_lags = 200 # will be far too high - with pytest.raises(ValueError, match='the number of lags'): + with pytest.raises(ValueError, match="the number of lags"): spectral_connectivity_epochs( - data, method=method, mode=mode, indices=indices, sfreq=sfreq, - fmin=frange[0], fmax=frange[1], gc_n_lags=n_lags, - cwt_freqs=cwt_freqs) + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + fmin=frange[0], + fmax=frange[1], + gc_n_lags=n_lags, + cwt_freqs=cwt_freqs, + ) # check no indices caught - with pytest.raises(ValueError, match='indices must be specified'): - spectral_connectivity_epochs(data, method=method, mode=mode, - indices=None, sfreq=sfreq, - cwt_freqs=cwt_freqs) + with pytest.raises(ValueError, match="indices must be specified"): + spectral_connectivity_epochs( + data, + method=method, + mode=mode, + indices=None, + sfreq=sfreq, + cwt_freqs=cwt_freqs, + 
) # check intersecting indices caught bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) - with pytest.raises(ValueError, - match='seed and target indices must not intersect'): - spectral_connectivity_epochs(data, method=method, mode=mode, - indices=bad_indices, sfreq=sfreq, - cwt_freqs=cwt_freqs) + with pytest.raises( + ValueError, match="seed and target indices must not intersect" + ): + spectral_connectivity_epochs( + data, + method=method, + mode=mode, + indices=bad_indices, + sfreq=sfreq, + cwt_freqs=cwt_freqs, + ) # check bad fmin/fmax caught - with pytest.raises(ValueError, - match='computing Granger causality on multiple'): - spectral_connectivity_epochs(data, method=method, mode=mode, - indices=indices, sfreq=sfreq, - fmin=(10., 15.), fmax=(15., 20.), - cwt_freqs=cwt_freqs) + with pytest.raises(ValueError, match="computing Granger causality on multiple"): + spectral_connectivity_epochs( + data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + fmin=(10.0, 15.0), + fmax=(15.0, 20.0), + cwt_freqs=cwt_freqs, + ) # check rank-deficient autocovariance caught - with pytest.raises(RuntimeError, - match='the autocovariance matrix is singular'): + with pytest.raises(RuntimeError, match="the autocovariance matrix is singular"): spectral_connectivity_epochs( - bad_data, method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=(np.array([2]), np.array([2])), - cwt_freqs=cwt_freqs) + bad_data, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + rank=(np.array([2]), np.array([2])), + cwt_freqs=cwt_freqs, + ) -@pytest.mark.parametrize('method', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", "gc", "gc_tr"]) def test_multivar_spectral_connectivity_parallel(method): """Test multivar. freq.-domain connectivity methods run in parallel.""" - sfreq = 50. + sfreq = 50.0 n_signals = 4 # Do not change! 
n_epochs = 8 n_times = 256 @@ -744,16 +942,29 @@ def test_multivar_spectral_connectivity_parallel(method): indices = (np.array([[0, 1]]), np.array([[2, 3]])) spectral_connectivity_epochs( - data, method=method, mode="multitaper", indices=indices, sfreq=sfreq, - gc_n_lags=10, n_jobs=2) + data, + method=method, + mode="multitaper", + indices=indices, + sfreq=sfreq, + gc_n_lags=10, + n_jobs=2, + ) spectral_connectivity_time( - data, freqs=np.arange(10, 25), method=method, mode="multitaper", - indices=indices, sfreq=sfreq, gc_n_lags=10, n_jobs=2) + data, + freqs=np.arange(10, 25), + method=method, + mode="multitaper", + indices=indices, + sfreq=sfreq, + gc_n_lags=10, + n_jobs=2, + ) def test_multivar_spectral_connectivity_flipped_indices(): """Test multivar. indices structure maintained by connectivity methods.""" - sfreq = 50. + sfreq = 50.0 n_signals = 4 n_epochs = 8 n_times = 256 @@ -763,109 +974,120 @@ def test_multivar_spectral_connectivity_flipped_indices(): # if we're not careful, when finding the channels we need to compute the # CSD for, we might accidentally reorder the connectivity indices - indices = (np.array([[0, 1]]), - np.array([[2, 3]])) - flipped_indices = (np.array([[2, 3]]), - np.array([[0, 1]])) - concat_indices = (np.array([[0, 1], [2, 3]]), - np.array([[2, 3], [0, 1]])) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) + flipped_indices = (np.array([[2, 3]]), np.array([[0, 1]])) + concat_indices = (np.array([[0, 1], [2, 3]]), np.array([[2, 3], [0, 1]])) # we test on GC since this is a directed connectivity measure - method = 'gc' + method = "gc" con_st = spectral_connectivity_epochs( # seed -> target - data, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10) + data, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10 + ) con_ts = spectral_connectivity_epochs( # target -> seed - data, method=method, indices=flipped_indices, sfreq=sfreq, - gc_n_lags=10) + data, method=method, indices=flipped_indices, sfreq=sfreq, gc_n_lags=10 + ) 
con_st_ts = spectral_connectivity_epochs( # seed -> target; target -> seed - data, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10) + data, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10 + ) assert not np.all(con_st.get_data() == con_ts.get_data()) assert np.all(con_st.get_data()[0] == con_st_ts.get_data()[0]) assert np.all(con_ts.get_data()[0] == con_st_ts.get_data()[1]) con_st = spectral_connectivity_time( # seed -> target - data, freqs, method=method, indices=indices, sfreq=sfreq, - gc_n_lags=10) + data, freqs, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10 + ) con_ts = spectral_connectivity_time( # target -> seed - data, freqs, method=method, indices=flipped_indices, sfreq=sfreq, - gc_n_lags=10) + data, freqs, method=method, indices=flipped_indices, sfreq=sfreq, gc_n_lags=10 + ) con_st_ts = spectral_connectivity_time( # seed -> target; target -> seed - data, freqs, method=method, indices=concat_indices, sfreq=sfreq, - gc_n_lags=10) + data, freqs, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10 + ) assert not np.all(con_st.get_data() == con_ts.get_data()) assert np.all(con_st.get_data()[:, 0] == con_st_ts.get_data()[:, 0]) assert np.all(con_ts.get_data()[:, 0] == con_st_ts.get_data()[:, 1]) -@ pytest.mark.parametrize('kind', ('epochs', 'ndarray', 'stc', 'combo')) +@pytest.mark.parametrize("kind", ("epochs", "ndarray", "stc", "combo")) def test_epochs_tmin_tmax(kind): """Test spectral.spectral_connectivity_epochs with epochs and arrays.""" rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 10, 2, 2000, 1000., 20. 
+ n_epochs, n_chs, n_times, sfreq, f = 10, 2, 2000, 1000.0, 20.0 data = rng.randn(n_epochs, n_chs, n_times) sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) data[:, :, 500:1500] += sig - info = create_info(n_chs, sfreq, 'eeg') - if kind == 'epochs': + info = create_info(n_chs, sfreq, "eeg") + if kind == "epochs": tmin = -1 X = EpochsArray(data, info, tmin=tmin) - elif kind == 'stc': + elif kind == "stc": tmin = -1 - X = [SourceEstimate(d, [[0], [0]], tmin, 1. / sfreq) for d in data] - elif kind == 'combo': + X = [SourceEstimate(d, [[0], [0]], tmin, 1.0 / sfreq) for d in data] + elif kind == "combo": tmin = -1 - X = [(d[[0]], SourceEstimate(d[[1]], [[0], []], tmin, 1. / sfreq)) - for d in data] + X = [ + (d[[0]], SourceEstimate(d[[1]], [[0], []], tmin, 1.0 / sfreq)) for d in data + ] else: - assert kind == 'ndarray' + assert kind == "ndarray" tmin = 0 X = data want_times = np.arange(n_times) / sfreq + tmin # Parameters for computing connectivity fmin, fmax = f - 2, f + 2 - kwargs = {'method': 'coh', 'mode': 'multitaper', 'sfreq': sfreq, - 'fmin': fmin, 'fmax': fmax, 'faverage': True, - 'mt_adaptive': False, 'n_jobs': 1} + kwargs = { + "method": "coh", + "mode": "multitaper", + "sfreq": sfreq, + "fmin": fmin, + "fmax": fmax, + "faverage": True, + "mt_adaptive": False, + "n_jobs": 1, + } # Check the entire interval conn = spectral_connectivity_epochs(X, **kwargs) - assert 0.89 < conn.get_data(output='dense')[1, 0] < 0.91 - assert_allclose(conn.attrs.get('times_used'), want_times) + assert 0.89 < conn.get_data(output="dense")[1, 0] < 0.91 + assert_allclose(conn.attrs.get("times_used"), want_times) # Check a time interval before the sinusoid conn = spectral_connectivity_epochs(X, tmax=tmin + 0.5, **kwargs) - assert 0 < conn.get_data(output='dense')[1, 0] < 0.15 + assert 0 < conn.get_data(output="dense")[1, 0] < 0.15 # Check a time during the sinusoid - conn = spectral_connectivity_epochs( - X, tmin=tmin + 0.5, tmax=tmin + 1.5, **kwargs) - assert 
0.93 < conn.get_data(output='dense')[1, 0] <= 0.94 + conn = spectral_connectivity_epochs(X, tmin=tmin + 0.5, tmax=tmin + 1.5, **kwargs) + assert 0.93 < conn.get_data(output="dense")[1, 0] <= 0.94 # Check a time interval after the sinusoid - conn = spectral_connectivity_epochs( - X, tmin=tmin + 1.5, tmax=tmin + 1.9, **kwargs) - assert 0 < conn.get_data(output='dense')[1, 0] < 0.15 + conn = spectral_connectivity_epochs(X, tmin=tmin + 1.5, tmax=tmin + 1.9, **kwargs) + assert 0 < conn.get_data(output="dense")[1, 0] < 0.15 # Check for warning if tmin, tmax is outside of the time limits of data - with pytest.warns(RuntimeWarning, match='start time tmin'): + with pytest.warns(RuntimeWarning, match="start time tmin"): spectral_connectivity_epochs(X, **kwargs, tmin=tmin - 0.1) - with pytest.warns(RuntimeWarning, match='stop time tmax'): + with pytest.warns(RuntimeWarning, match="stop time tmax"): spectral_connectivity_epochs(X, **kwargs, tmax=tmin + 2.5) # make one with mismatched times - if kind != 'combo': + if kind != "combo": return - X = [(SourceEstimate(d[[0]], [[0], []], tmin - 1, 1. / sfreq), - SourceEstimate(d[[1]], [[0], []], tmin, 1. 
/ sfreq)) for d in data] - with pytest.warns(RuntimeWarning, match='time scales of input') as w: + X = [ + ( + SourceEstimate(d[[0]], [[0], []], tmin - 1, 1.0 / sfreq), + SourceEstimate(d[[1]], [[0], []], tmin, 1.0 / sfreq), + ) + for d in data + ] + with pytest.warns(RuntimeWarning, match="time scales of input") as w: spectral_connectivity_epochs(X, **kwargs) assert len(w) == 1 # just one even though there were multiple epochs @pytest.mark.parametrize( - 'method', ['coh', 'cacoh', 'mic', 'mim', 'plv', 'pli', 'wpli', 'ciplv']) -@pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper']) -@pytest.mark.parametrize('data_option', ['sync', 'random']) + "method", ["coh", "cacoh", "mic", "mim", "plv", "pli", "wpli", "ciplv"] +) +@pytest.mark.parametrize("mode", ["cwt_morlet", "multitaper"]) +@pytest.mark.parametrize("data_option", ["sync", "random"]) def test_spectral_connectivity_time_phaselocked(method, mode, data_option): """Test time-resolved spectral connectivity with simulated phase-locked data.""" @@ -875,10 +1097,10 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): n_times = 1000 sfreq = 250 data = np.zeros((n_epochs, n_channels, n_times)) - if data_option == 'random': + if data_option == "random": # Data is random, there should be no consistent phase differences. data = rng.random((n_epochs, n_channels, n_times)) - if data_option == 'sync': + if data_option == "sync": # Data consists of phase-locked 10Hz sine waves with constant phase # difference within each epoch. 
wave_freq = 10 @@ -886,14 +1108,16 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): for i in range(n_epochs): for c in range(n_channels): phase = rng.random() * 10 - x = np.linspace(-wave_freq * epoch_length * np.pi + phase, - wave_freq * epoch_length * np.pi + phase, - n_times) + x = np.linspace( + -wave_freq * epoch_length * np.pi + phase, + wave_freq * epoch_length * np.pi + phase, + n_times, + ) data[i, c] = np.squeeze(np.sin(x)) - multivar_methods = ['cacoh', 'mic', 'mim'] + multivar_methods = ["cacoh", "mic", "mim"] - if method == 'cacoh': + if method == "cacoh": # CaCoh within set of signals will always be 1, so need to specify # distinct seeds and targets indices = ([[0, 1]], [[2, 3]]) @@ -902,34 +1126,44 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): # the frequency band should contain the frequency at which there is a # hypothesized "connection" - freq_band_low_limit = (8.) - freq_band_high_limit = (13.) + freq_band_low_limit = 8.0 + freq_band_high_limit = 13.0 freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) con = spectral_connectivity_time( - data, freqs, indices=indices, method=method, mode=mode, sfreq=sfreq, - fmin=freq_band_low_limit, fmax=freq_band_high_limit, n_jobs=1, - faverage=True if method != 'mic' else False, - average=True if method != 'mic' else False, sm_times=0) + data, + freqs, + indices=indices, + method=method, + mode=mode, + sfreq=sfreq, + fmin=freq_band_low_limit, + fmax=freq_band_high_limit, + n_jobs=1, + faverage=True if method != "mic" else False, + average=True if method != "mic" else False, + sm_times=0, + ) con_matrix = con.get_data() # MIC values can be pos. 
and neg., so must be averaged after taking the # absolute values for the test to work if method in multivar_methods: - if method in ['mic']: + if method in ["mic"]: con_matrix = np.mean(np.abs(con_matrix), axis=(0, 2)) assert con.shape == (n_epochs, 1, len(con.freqs)) else: assert con.shape == (1, len(con.freqs)) else: - assert con.shape == (n_channels ** 2, len(con.freqs)) + assert con.shape == (n_channels**2, len(con.freqs)) con_matrix = np.reshape(con_matrix, (n_channels, n_channels))[ - np.tril_indices(n_channels, -1)] + np.tril_indices(n_channels, -1) + ] - if data_option == 'sync': + if data_option == "sync": # signals are perfectly phase-locked, connectivity matrix should be # a matrix of ones assert np.allclose(con_matrix, np.ones(con_matrix.shape), atol=0.01) - if data_option == 'random': + if data_option == "random": # signals are random, all connectivity values should be small # 0.5 is picked rather arbitrarily such that the obsolete wrong # implementation fails @@ -944,7 +1178,7 @@ def test_spectral_connectivity_time_delayed(): Granger scores only in the context of the noise-corrected TRGC metric, where the true directionality of the connections seems to identified. """ - mode = 'multitaper' # stick with single mode in interest of time + mode = "multitaper" # stick with single mode in interest of time sfreq = 100.0 # Hz n_signals = 4 # should be even! 
@@ -962,25 +1196,48 @@ def test_spectral_connectivity_time_delayed(): data = rng.randn(n_signals, n_epochs * n_times + delay) # simulate connectivity from fstart to fend data[n_seeds:, :] = filter_data( - data[:n_seeds, :], sfreq, fstart, fend, filter_length='auto', - fir_design='firwin2', l_trans_bandwidth=trans_bandwidth, - h_trans_bandwidth=trans_bandwidth) + data[:n_seeds, :], + sfreq, + fstart, + fend, + filter_length="auto", + fir_design="firwin2", + l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth, + ) # add some noise, so the spectrum is not exactly zero data[n_seeds:, :] += 1e-2 * rng.randn(n_seeds, n_times * n_epochs + delay) # shift the seeds to that the targets are a delayed version of them - data[:n_seeds, :n_epochs * n_times] = data[:n_seeds, delay:] - data = data[:, :n_times * n_epochs] + data[:n_seeds, : n_epochs * n_times] = data[:n_seeds, delay:] + data = data[:, : n_times * n_epochs] data = data.reshape(n_signals, n_epochs, n_times) data = np.transpose(data, [1, 0, 2]) freqs = np.arange(2.5, 50, 0.5) con_st = spectral_connectivity_time( - data, freqs, method=['gc', 'gc_tr'], indices=indices, mode=mode, - sfreq=sfreq, n_jobs=1, gc_n_lags=20, n_cycles=5, average=True) + data, + freqs, + method=["gc", "gc_tr"], + indices=indices, + mode=mode, + sfreq=sfreq, + n_jobs=1, + gc_n_lags=20, + n_cycles=5, + average=True, + ) con_ts = spectral_connectivity_time( - data, freqs, method=['gc', 'gc_tr'], indices=(indices[1], indices[0]), - mode=mode, sfreq=sfreq, n_jobs=1, gc_n_lags=20, n_cycles=5, - average=True) + data, + freqs, + method=["gc", "gc_tr"], + indices=(indices[1], indices[0]), + mode=mode, + sfreq=sfreq, + n_jobs=1, + gc_n_lags=20, + n_cycles=5, + average=True, + ) st = con_st[0].get_data() st_tr = con_st[1].get_data() ts = con_ts[0].get_data() @@ -989,23 +1246,24 @@ def test_spectral_connectivity_time_delayed(): freqs = con_st[0].freqs gidx = (freqs.index(fstart), freqs.index(fend) + 1) - bidx = (freqs.index(fstart - 
trans_bandwidth * 2), - freqs.index(fend + trans_bandwidth * 2) + 1) + bidx = ( + freqs.index(fstart - trans_bandwidth * 2), + freqs.index(fend + trans_bandwidth * 2) + 1, + ) # assert that TRGC (i.e. net, noise-corrected connectivity) is positive and # >> 0 (i.e. that there is indeed a flow of info. from seeds to targets, # as simulated) - assert np.all(trgc[:, gidx[0]:gidx[1]] > 0) - assert trgc[:, gidx[0]:gidx[1]].mean() > 0.4 + assert np.all(trgc[:, gidx[0] : gidx[1]] > 0) + assert trgc[:, gidx[0] : gidx[1]].mean() > 0.4 # check that non-interacting freqs. have close to zero connectivity - assert np.allclose(trgc[0, :bidx[0]].mean(), 0, atol=0.1) - assert np.allclose(trgc[0, bidx[1]:].mean(), 0, atol=0.1) + assert np.allclose(trgc[0, : bidx[0]].mean(), 0, atol=0.1) + assert np.allclose(trgc[0, bidx[1] :].mean(), 0, atol=0.1) -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) -@pytest.mark.parametrize( - 'freqs', [[8., 10.], [8, 10], 10., 10]) -@pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper']) +@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv"]) +@pytest.mark.parametrize("freqs", [[8.0, 10.0], [8, 10], 10.0, 10]) +@pytest.mark.parametrize("mode", ["cwt_morlet", "multitaper"]) def test_spectral_connectivity_time_freqs(method, freqs, mode): """Test time-resolved spectral connectivity with int and float values for freqs.""" @@ -1023,46 +1281,61 @@ def test_spectral_connectivity_time_freqs(method, freqs, mode): for i in range(n_epochs): for c in range(n_channels): phase = rng.random() * 10 - x = np.linspace(-wave_freq * epoch_length * np.pi + phase, - wave_freq * epoch_length * np.pi + phase, - n_times) + x = np.linspace( + -wave_freq * epoch_length * np.pi + phase, + wave_freq * epoch_length * np.pi + phase, + n_times, + ) data[i, c] = np.squeeze(np.sin(x)) # the frequency band should contain the frequency at which there is a # hypothesized "connection" - con = spectral_connectivity_time(data, freqs, 
method=method, - mode=mode, sfreq=sfreq, - fmin=np.min(freqs), - fmax=np.max(freqs), n_jobs=1, - faverage=True, average=True, sm_times=0) - assert con.shape == (n_channels ** 2, len(con.freqs)) - con_matrix = con.get_data('dense')[..., 0] + con = spectral_connectivity_time( + data, + freqs, + method=method, + mode=mode, + sfreq=sfreq, + fmin=np.min(freqs), + fmax=np.max(freqs), + n_jobs=1, + faverage=True, + average=True, + sm_times=0, + ) + assert con.shape == (n_channels**2, len(con.freqs)) + con_matrix = con.get_data("dense")[..., 0] # signals are perfectly phase-locked, connectivity matrix should be # a lower triangular matrix of ones - assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1), - atol=0.01) + assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1), atol=0.01) -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) -@pytest.mark.parametrize( - 'mode', ['cwt_morlet', 'multitaper']) +@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli"]) +@pytest.mark.parametrize("mode", ["cwt_morlet", "multitaper"]) def test_spectral_connectivity_time_resolved(method, mode): """Test time-resolved spectral connectivity.""" - sfreq = 50. + sfreq = 50.0 n_signals = 3 n_epochs = 2 n_times = 1000 - trans_bandwidth = 2. - tmin = 0. 
+ trans_bandwidth = 2.0 + tmin = 0.0 tmax = (n_times - 1) / sfreq # 5Hz..15Hz fstart, fend = 5.0, 15.0 data, _ = create_test_dataset( - sfreq, n_signals=n_signals, n_epochs=n_epochs, n_times=n_times, - tmin=tmin, tmax=tmax, - fstart=fstart, fend=fend, trans_bandwidth=trans_bandwidth) + sfreq, + n_signals=n_signals, + n_epochs=n_epochs, + n_times=n_times, + tmin=tmin, + tmax=tmax, + fstart=fstart, + fend=fend, + trans_bandwidth=trans_bandwidth, + ) ch_names = np.arange(n_signals).astype(str).tolist() - info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg") data = EpochsArray(data, info) # define some frequencies for tfr @@ -1070,47 +1343,58 @@ def test_spectral_connectivity_time_resolved(method, mode): # run connectivity estimation con = spectral_connectivity_time( - data, freqs, sfreq=sfreq, method=method, mode=mode, - n_cycles=5) - assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) - assert con.get_data(output='dense').shape == \ - (n_epochs, n_signals, n_signals, len(con.freqs)) + data, freqs, sfreq=sfreq, method=method, mode=mode, n_cycles=5 + ) + assert con.shape == (n_epochs, n_signals**2, len(con.freqs)) + assert con.get_data(output="dense").shape == ( + n_epochs, + n_signals, + n_signals, + len(con.freqs), + ) # test the simulated signal triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T # average over frequencies - conn_data = con.get_data(output='dense').mean(axis=-1) + conn_data = con.get_data(output="dense").mean(axis=-1) # the indices at which there is a correlation should be greater # then the rest of the components for epoch_idx in range(n_epochs): high_conn_val = conn_data[epoch_idx, 0, 1] - assert all(high_conn_val >= conn_data[epoch_idx, idx, jdx] - for idx, jdx in triu_inds) + assert all( + high_conn_val >= conn_data[epoch_idx, idx, jdx] for idx, jdx in triu_inds + ) -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) 
-@pytest.mark.parametrize( - 'mode', ['cwt_morlet', 'multitaper']) -@pytest.mark.parametrize('padding', [0, 1, 5]) +@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli"]) +@pytest.mark.parametrize("mode", ["cwt_morlet", "multitaper"]) +@pytest.mark.parametrize("padding", [0, 1, 5]) def test_spectral_connectivity_time_padding(method, mode, padding): """Test time-resolved spectral connectivity with padding.""" - sfreq = 50. + sfreq = 50.0 n_signals = 3 n_epochs = 2 n_times = 300 - trans_bandwidth = 2. - tmin = 0. + trans_bandwidth = 2.0 + tmin = 0.0 tmax = (n_times - 1) / sfreq # 5Hz..15Hz fstart, fend = 5.0, 15.0 data, _ = create_test_dataset( - sfreq, n_signals=n_signals, n_epochs=n_epochs, n_times=n_times, - tmin=tmin, tmax=tmax, - fstart=fstart, fend=fend, trans_bandwidth=trans_bandwidth) + sfreq, + n_signals=n_signals, + n_epochs=n_epochs, + n_times=n_times, + tmin=tmin, + tmax=tmax, + fstart=fstart, + fend=fend, + trans_bandwidth=trans_bandwidth, + ) ch_names = np.arange(n_signals).astype(str).tolist() - info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg") data = EpochsArray(data, info) # define some frequencies for tfr @@ -1118,41 +1402,59 @@ def test_spectral_connectivity_time_padding(method, mode, padding): # run connectivity estimation if padding == 5: - with pytest.raises(ValueError, match='Padding cannot be larger than ' - 'half of data length'): + with pytest.raises( + ValueError, match="Padding cannot be larger than " "half of data length" + ): con = spectral_connectivity_time( - data, freqs, sfreq=sfreq, method=method, mode=mode, - n_cycles=5, padding=padding) + data, + freqs, + sfreq=sfreq, + method=method, + mode=mode, + n_cycles=5, + padding=padding, + ) return else: con = spectral_connectivity_time( - data, freqs, sfreq=sfreq, method=method, mode=mode, - n_cycles=5, padding=padding) - - assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) - 
assert con.get_data(output='dense').shape == \ - (n_epochs, n_signals, n_signals, len(con.freqs)) + data, + freqs, + sfreq=sfreq, + method=method, + mode=mode, + n_cycles=5, + padding=padding, + ) + + assert con.shape == (n_epochs, n_signals**2, len(con.freqs)) + assert con.get_data(output="dense").shape == ( + n_epochs, + n_signals, + n_signals, + len(con.freqs), + ) # test the simulated signal triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T # average over frequencies - conn_data = con.get_data(output='dense').mean(axis=-1) + conn_data = con.get_data(output="dense").mean(axis=-1) # the indices at which there is a correlation should be greater # then the rest of the components for epoch_idx in range(n_epochs): high_conn_val = conn_data[epoch_idx, 0, 1] - assert all(high_conn_val >= conn_data[epoch_idx, idx, jdx] - for idx, jdx in triu_inds) + assert all( + high_conn_val >= conn_data[epoch_idx, idx, jdx] for idx, jdx in triu_inds + ) -@pytest.mark.parametrize('method', ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']) -@pytest.mark.parametrize('average', [True, False]) -@pytest.mark.parametrize('faverage', [True, False]) +@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", "gc", "gc_tr"]) +@pytest.mark.parametrize("average", [True, False]) +@pytest.mark.parametrize("faverage", [True, False]) def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): """Test result shapes of time-resolved multivar. connectivity methods.""" - sfreq = 50. + sfreq = 50.0 n_signals = 4 # Do not change! 
n_epochs = 8 n_times = 500 @@ -1172,14 +1474,21 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): # check shape of results when averaging across epochs con = spectral_connectivity_time( - data, freqs, indices=indices, method=method, sfreq=sfreq, - faverage=faverage, average=average, gc_n_lags=10) + data, + freqs, + indices=indices, + method=method, + sfreq=sfreq, + faverage=faverage, + average=average, + gc_n_lags=10, + ) assert con.shape == tuple(con_shape) # check shape of MIC patterns are correct - if method in ['cacoh', 'mic']: - for indices_type in ['full', 'ragged']: - if indices_type == 'full': + if method in ["cacoh", "mic"]: + for indices_type in ["full", "ragged"]: + if indices_type == "full": indices = (np.array([[0, 1]]), np.array([[2, 3]])) else: indices = (np.array([[0, 1]]), np.array([[2]])) @@ -1194,25 +1503,34 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): patterns_shape = [2, *patterns_shape] con = spectral_connectivity_time( - data, freqs, indices=indices, method=method, sfreq=sfreq, - faverage=faverage, average=average, gc_n_lags=10) + data, + freqs, + indices=indices, + method=method, + sfreq=sfreq, + faverage=faverage, + average=average, + gc_n_lags=10, + ) - patterns = np.array(con.attrs['patterns']) + patterns = np.array(con.attrs["patterns"]) # 2 (x epochs) x cons x channels x freqs|fbands - assert (patterns.shape == tuple(patterns_shape)) - if indices_type == 'ragged': + assert patterns.shape == tuple(patterns_shape) + if indices_type == "ragged": assert not np.any(np.isnan(patterns[0, ..., :, :])) assert not np.any(np.isnan(patterns[0, ..., 0, :])) assert np.all(np.isnan(patterns[1, ..., 1, :])) # padded entry - assert np.all(np.array(con.indices) == np.array( - (np.array([[0, 1]]), np.array([[2, -1]])))) + assert np.all( + np.array(con.indices) + == np.array((np.array([[0, 1]]), np.array([[2, -1]]))) + ) -@pytest.mark.parametrize('method', ['cacoh', 'mic', 'mim', 'gc', 
'gc_tr']) -@pytest.mark.parametrize('mode', ['multitaper', 'cwt_morlet']) +@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", "gc", "gc_tr"]) +@pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) def test_multivar_spectral_connectivity_time_error_catch(method, mode): """Test error catching for time-resolved multivar. connectivity methods.""" - sfreq = 50. + sfreq = 50.0 n_signals = 4 # Do not change! n_epochs = 8 n_times = 256 @@ -1221,132 +1539,178 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): freqs = np.arange(10, 25 + 1) # check bad indices without nested array caught - with pytest.raises(TypeError, - match='multivariate indices must contain array-likes'): + with pytest.raises( + TypeError, match="multivariate indices must contain array-likes" + ): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) - spectral_connectivity_time(data, freqs, method=method, mode=mode, - indices=non_nested_indices, sfreq=sfreq) + spectral_connectivity_time( + data, + freqs, + method=method, + mode=mode, + indices=non_nested_indices, + sfreq=sfreq, + ) # check bad indices with repeated channels caught - with pytest.raises(ValueError, - match='multivariate indices cannot contain repeated'): + with pytest.raises( + ValueError, match="multivariate indices cannot contain repeated" + ): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) - spectral_connectivity_time(data, freqs, method=method, mode=mode, - indices=repeated_indices, sfreq=sfreq) + spectral_connectivity_time( + data, freqs, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq + ) # check mixed methods caught - with pytest.raises(ValueError, - match='bivariate and multivariate connectivity'): - mixed_methods = [method, 'coh'] - spectral_connectivity_time(data, freqs, method=mixed_methods, - mode=mode, indices=indices, sfreq=sfreq) + with pytest.raises(ValueError, match="bivariate and multivariate connectivity"): + mixed_methods = [method, "coh"] + 
spectral_connectivity_time( + data, freqs, method=mixed_methods, mode=mode, indices=indices, sfreq=sfreq + ) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) - with pytest.raises(ValueError, - match='ranks for seeds and targets must be'): + with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_time( - data, freqs, method=method, indices=indices, sfreq=sfreq, - mode=mode, rank=too_low_rank) + data, + freqs, + method=method, + indices=indices, + sfreq=sfreq, + mode=mode, + rank=too_low_rank, + ) too_high_rank = (np.array([3]), np.array([3])) - with pytest.raises(ValueError, - match='ranks for seeds and targets must be'): + with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_time( - data, freqs, method=method, indices=indices, sfreq=sfreq, - mode=mode, rank=too_high_rank) + data, + freqs, + method=method, + indices=indices, + sfreq=sfreq, + mode=mode, + rank=too_high_rank, + ) too_few_rank = ([], []) - with pytest.raises(ValueError, match='rank argument must have shape'): + with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_time( - data, freqs, method=method, indices=indices, sfreq=sfreq, - mode=mode, rank=too_few_rank) + data, + freqs, + method=method, + indices=indices, + sfreq=sfreq, + mode=mode, + rank=too_few_rank, + ) too_much_rank = (np.array([2, 2]), np.array([2, 2])) - with pytest.raises(ValueError, match='rank argument must have shape'): + with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_time( - data, freqs, method=method, indices=indices, sfreq=sfreq, - mode=mode, rank=too_much_rank) + data, + freqs, + method=method, + indices=indices, + sfreq=sfreq, + mode=mode, + rank=too_much_rank, + ) # check all-to-all conn. 
computed for MIC/MIM when no indices given - if method in ['cacoh', 'mic', 'mim']: + if method in ["cacoh", "mic", "mim"]: con = spectral_connectivity_time( - data, freqs, method=method, indices=None, sfreq=sfreq, mode=mode) + data, freqs, method=method, indices=None, sfreq=sfreq, mode=mode + ) assert con.indices is None assert con.n_nodes == n_signals - if method == ['cacoh', 'mic']: - assert np.array(con.attrs['patterns']).shape[3] == n_signals + if method == ["cacoh", "mic"]: + assert np.array(con.attrs["patterns"]).shape[3] == n_signals - if method in ['gc', 'gc_tr']: + if method in ["gc", "gc_tr"]: # check no indices caught - with pytest.raises(ValueError, match='indices must be specified'): - spectral_connectivity_time(data, freqs, method=method, mode=mode, - indices=None, sfreq=sfreq) + with pytest.raises(ValueError, match="indices must be specified"): + spectral_connectivity_time( + data, freqs, method=method, mode=mode, indices=None, sfreq=sfreq + ) # check intersecting indices caught bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) - with pytest.raises(ValueError, - match='seed and target indices must not intersect'): - spectral_connectivity_time(data, freqs, method=method, mode=mode, - indices=bad_indices, sfreq=sfreq) + with pytest.raises( + ValueError, match="seed and target indices must not intersect" + ): + spectral_connectivity_time( + data, freqs, method=method, mode=mode, indices=bad_indices, sfreq=sfreq + ) # check bad fmin/fmax caught - with pytest.raises(ValueError, - match='computing Granger causality on multiple'): - spectral_connectivity_time(data, freqs, method=method, mode=mode, - indices=indices, sfreq=sfreq, - fmin=(5., 15.), fmax=(15., 30.)) + with pytest.raises(ValueError, match="computing Granger causality on multiple"): + spectral_connectivity_time( + data, + freqs, + method=method, + mode=mode, + indices=indices, + sfreq=sfreq, + fmin=(5.0, 15.0), + fmax=(15.0, 30.0), + ) def test_save(tmp_path): """Test saving results of 
spectral connectivity.""" rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 10, 3, 2000, 1000., 20. + n_epochs, n_chs, n_times, sfreq, f = 10, 3, 2000, 1000.0, 20.0 data = rng.randn(n_epochs, n_chs, n_times) sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) data[:, :, 500:1500] += sig - info = create_info(n_chs, sfreq, 'eeg') + info = create_info(n_chs, sfreq, "eeg") tmin = -1 epochs = EpochsArray(data, info, tmin=tmin) conn = spectral_connectivity_epochs( - epochs, fmin=(4, 8, 13, 30), fmax=(8, 13, 30, 45), - faverage=True) - conn.save(tmp_path / 'foo.nc') + epochs, fmin=(4, 8, 13, 30), fmax=(8, 13, 30, 45), faverage=True + ) + conn.save(tmp_path / "foo.nc") def test_multivar_save_load(tmp_path): """Test saving and loading results of multivariate connectivity.""" rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000., 20. + n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000.0, 20.0 data = rng.randn(n_epochs, n_chs, n_times) sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) data[:, :, 500:1500] += sig - info = create_info(n_chs, sfreq, 'eeg') + info = create_info(n_chs, sfreq, "eeg") tmin = -1 epochs = EpochsArray(data, info, tmin=tmin) - tmp_file = os.path.join(tmp_path, 'foo_mvc.nc') + tmp_file = os.path.join(tmp_path, "foo_mvc.nc") non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]])) ragged_indices = (np.array([[0, 1]]), np.array([[2]])) for indices in [non_ragged_indices, ragged_indices]: con = spectral_connectivity_epochs( - epochs, method=['cacoh', 'mic', 'mim', 'gc', 'gc_tr'], - indices=indices, sfreq=sfreq, fmin=10, fmax=30) + epochs, + method=["cacoh", "mic", "mim", "gc", "gc_tr"], + indices=indices, + sfreq=sfreq, + fmin=10, + fmax=30, + ) for this_con in con: this_con.save(tmp_file) read_con = read_connectivity(tmp_file) - assert_array_almost_equal(this_con.get_data(), - read_con.get_data('raveled')) - if this_con.attrs['patterns'] is not None: 
- assert_array_almost_equal(np.array(this_con.attrs['patterns']), - np.array(read_con.attrs['patterns'])) + assert_array_almost_equal(this_con.get_data(), read_con.get_data("raveled")) + if this_con.attrs["patterns"] is not None: + assert_array_almost_equal( + np.array(this_con.attrs["patterns"]), + np.array(read_con.attrs["patterns"]), + ) # split `repr` before the file size (`~23 kB` for example) - a = repr(this_con).split('~')[0] - b = repr(read_con).split('~')[0] + a = repr(this_con).split("~")[0] + b = repr(read_con).split("~")[0] assert a == b -@pytest.mark.parametrize("method", ['coh', 'plv', 'pli', 'wpli', 'ciplv']) -@pytest.mark.parametrize("indices", [None, - (np.array([0, 1]), np.array([2, 3]))]) +@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv"]) +@pytest.mark.parametrize("indices", [None, (np.array([0, 1]), np.array([2, 3]))]) def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. @@ -1386,12 +1750,9 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): assert con.indices is None and read_con.indices is None -@pytest.mark.parametrize("method", ['cacoh', 'mic', 'mim', 'gc', 'gc_tr']) -@pytest.mark.parametrize("indices", [None, - (np.array([[0, 1]]), np.array([[2, 3]]))]) -def test_multivar_spectral_connectivity_indices_roundtrip_io( - tmp_path, method, indices -): +@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", "gc", "gc_tr"]) +@pytest.mark.parametrize("indices", [None, (np.array([[0, 1]]), np.array([[2, 3]]))]) +def test_multivar_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. 
If `indices` is None, `indices` in the returned connectivity object should @@ -1408,17 +1769,21 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io( tmp_file = os.path.join(tmp_path, "foo_mvc.nc") # test the pair of method and indices defined to check the output indices - if indices is None and method in ['gc', 'gc_tr']: + if indices is None and method in ["gc", "gc_tr"]: # indicesmust be specified for GC pytest.skip() con_epochs = spectral_connectivity_epochs( - epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30, - gc_n_lags=10 + epochs, + method=method, + indices=indices, + sfreq=sfreq, + fmin=10, + fmax=30, + gc_n_lags=10, ) con_time = spectral_connectivity_time( - epochs, freqs, method=method, indices=indices, sfreq=sfreq, - gc_n_lags=10 + epochs, freqs, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10 ) for con in [con_epochs, con_time]: @@ -1431,10 +1796,16 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io( read_con.indices, tuple ) # check indices are masked - assert all([np.ma.isMA(inds) for inds in con.indices] and - [np.ma.isMA(inds) for inds in read_con.indices]) + assert all( + [np.ma.isMA(inds) for inds in con.indices] + and [np.ma.isMA(inds) for inds in read_con.indices] + ) # check indices have same values - assert np.all([con_inds == read_inds for con_inds, read_inds in - zip(con.indices, read_con.indices)]) + assert np.all( + [ + con_inds == read_inds + for con_inds, read_inds in zip(con.indices, read_con.indices) + ] + ) else: assert con.indices is None and read_con.indices is None From 788dddf03090a310cfb7fc10e8729ec3f302954d Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 11 Dec 2023 14:40:45 +0100 Subject: [PATCH 13/59] [MAINT] Ignore black formatting in git blame --- .git-blame-ignore-revs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 0dbf97cb..30258f97 100644 --- a/.git-blame-ignore-revs +++ 
b/.git-blame-ignore-revs @@ -1 +1,2 @@ -0c216d35127f48792caabdc3cca170874a443eee # black, isort, ruff \ No newline at end of file +0c216d35127f48792caabdc3cca170874a443eee # black, isort, ruff +e93fb47b3c5bb20164d5bb347b40692a52cfe430 # black \ No newline at end of file From 4296e5488150eca791be5ebe7057af91eb001dc4 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 11 Dec 2023 15:05:00 +0100 Subject: [PATCH 14/59] [BUG] Fix missing/incorrect multivariate tests --- .../spectral/tests/test_spectral.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 609e4108..212d11fc 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -716,9 +716,15 @@ def test_multivariate_spectral_connectivity_epochs_regression(): n_jobs=1, ) - # should take the absolute of the MIC scores, as the MATLAB implementation - # returns the absolute values. - mne_results = {this_con.method: np.abs(this_con.get_data()) for this_con in con} + mne_results = {} + for this_con in con: + # must take the absolute of the MIC scores, as the MATLAB + # implementation returns the absolute values. + if this_con.method == "mic": + mne_results[this_con.method] = np.abs(this_con.get_data()) + else: + mne_results[this_con.method] = this_con.get_data() + matlab_results = pd.read_pickle( os.path.join(fpath, "data", "example_multivariate_matlab_results.pkl") ) @@ -1146,10 +1152,10 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): ) con_matrix = con.get_data() - # MIC values can be pos. and neg., so must be averaged after taking the - # absolute values for the test to work + # CaCoh/MIC values can be pos. 
and neg., so must be averaged after taking + # the absolute values for the test to work if method in multivar_methods: - if method == "mic": + if method in ["cacoh", "mic"]: con_matrix = np.mean(np.abs(con_matrix), axis=(0, 2)) assert con.shape == (n_epochs, 1, len(con.freqs)) else: From 888e82d80339010f1a1229b9b95437e356da62f7 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 11 Dec 2023 15:05:20 +0100 Subject: [PATCH 15/59] [MAINT] Improve test code readability --- .../spectral/tests/test_spectral.py | 79 +++++++++---------- 1 file changed, 37 insertions(+), 42 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 212d11fc..7a064e96 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -593,13 +593,11 @@ def test_spectral_connectivity_epochs_multivariate(method): assert np.array(con.attrs["patterns"]).shape[2] == n_signals # check ragged indices padded correctly - ragged_indices = (np.array([[0]]), np.array([[1, 2]])) + ragged_indices = ([[0]], [[1, 2]]) con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq ) - assert np.all( - np.array(con.indices) == np.array([np.array([[0, -1]]), np.array([[1, 2]])]) - ) + assert np.all(np.array(con.indices) == np.array([[[0, -1]], [[1, 2]]])) # check shape of CaCoh/MIC patterns if method in ["cacoh", "mic"]: @@ -648,7 +646,7 @@ def test_spectral_connectivity_epochs_multivariate(method): assert np.shape(con.attrs["patterns"][1][0])[1] == len(fmin) # check patterns shape matches input data, not rank - rank = (np.array([1]), np.array([1])) + rank = ([1], [1]) con = spectral_connectivity_epochs( data, method=method, @@ -661,7 +659,7 @@ def test_spectral_connectivity_epochs_multivariate(method): assert np.shape(con.attrs["patterns"][1][0])[0] == n_targets # check patterns padded correctly - ragged_indices = (np.array([[0]]), 
np.array([[1, 2]])) + ragged_indices = ([[0]], [[1, 2]]) con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq ) @@ -696,7 +694,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): fpath = os.path.dirname(os.path.realpath(__file__)) data = pd.read_pickle(os.path.join(fpath, "data", "example_multivariate_data.pkl")) sfreq = 100 - indices = (np.array([[0, 1]]), np.array([[2, 3]])) + indices = ([[0, 1]], [[2, 3]]) methods = ["cacoh", "mic", "mim", "gc", "gc_tr"] con = spectral_connectivity_epochs( data, @@ -745,14 +743,14 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): n_times = 256 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.array([[0, 1]]), np.array([[2, 3]])) + indices = ([[0, 1]], [[2, 3]]) cwt_freqs = np.arange(10, 25 + 1) # check bad indices without nested array caught with pytest.raises( TypeError, match="multivariate indices must contain array-likes" ): - non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + non_nested_indices = ([0, 1], [2, 3]) spectral_connectivity_epochs( data, method=method, @@ -766,7 +764,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): with pytest.raises( ValueError, match="multivariate indices cannot contain repeated" ): - repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) + repeated_indices = ([[0, 1, 1]], [[2, 2, 3]]) spectral_connectivity_epochs( data, method=method, @@ -792,7 +790,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): ) # check bad rank args caught - too_low_rank = (np.array([0]), np.array([0])) + too_low_rank = ([0], [0]) with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_epochs( data, @@ -803,7 +801,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): rank=too_low_rank, cwt_freqs=cwt_freqs, ) - too_high_rank = 
(np.array([3]), np.array([3])) + too_high_rank = ([3], [3]) with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_epochs( data, @@ -825,7 +823,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): rank=too_few_rank, cwt_freqs=cwt_freqs, ) - too_much_rank = (np.array([2, 2]), np.array([2, 2])) + too_much_rank = ([2, 2], [2, 2]) with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_epochs( data, @@ -853,7 +851,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): gc_n_lags=10, cwt_freqs=cwt_freqs, ) - assert rank_con.attrs["rank"] == (np.array([1]), np.array([1])) + assert rank_con.attrs["rank"] == ([1], [1]) if method in ["cacoh", "mic", "mim"]: # check rank-deficient transformation matrix caught @@ -864,7 +862,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): mode=mode, indices=indices, sfreq=sfreq, - rank=(np.array([2]), np.array([2])), + rank=([2], [2]), cwt_freqs=cwt_freqs, ) @@ -898,7 +896,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): ) # check intersecting indices caught - bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) + bad_indices = ([[0, 1]], [[0, 2]]) with pytest.raises( ValueError, match="seed and target indices must not intersect" ): @@ -932,7 +930,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): mode=mode, indices=indices, sfreq=sfreq, - rank=(np.array([2]), np.array([2])), + rank=([2], [2]), cwt_freqs=cwt_freqs, ) @@ -946,7 +944,7 @@ def test_multivar_spectral_connectivity_parallel(method): n_times = 256 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.array([[0, 1]]), np.array([[2, 3]])) + indices = ([[0, 1]], [[2, 3]]) spectral_connectivity_epochs( data, @@ -981,9 +979,9 @@ def test_multivar_spectral_connectivity_flipped_indices(): # if we're not careful, when finding the 
channels we need to compute the # CSD for, we might accidentally reorder the connectivity indices - indices = (np.array([[0, 1]]), np.array([[2, 3]])) - flipped_indices = (np.array([[2, 3]]), np.array([[0, 1]])) - concat_indices = (np.array([[0, 1], [2, 3]]), np.array([[2, 3], [0, 1]])) + indices = ([[0, 1]], [[2, 3]]) + flipped_indices = ([[2, 3]], [[0, 1]]) + concat_indices = ([[0, 1], [2, 3]], [[2, 3], [0, 1]]) # we test on GC since this is a directed connectivity measure method = "gc" @@ -1146,8 +1144,8 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): fmin=freq_band_low_limit, fmax=freq_band_high_limit, n_jobs=1, - faverage=True if method != "mic" else False, - average=True if method != "mic" else False, + faverage=method not in ["cacoh", "mic"], + average=method not in ["cacoh", "mic"], sm_times=0, ) con_matrix = con.get_data() @@ -1195,7 +1193,7 @@ def test_spectral_connectivity_time_delayed(): trans_bandwidth = 2.0 # Hz delay = 5 # samples (non-zero delay needed for GC to be >> 0) - indices = (np.array([[0, 1]]), np.array([[2, 3]])) + indices = ([[0, 1]], [[2, 3]]) # 20-30 Hz connectivity fstart, fend = 20.0, 30.0 @@ -1467,7 +1465,7 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): n_times = 500 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.array([[0, 1]]), np.array([[2, 3]])) + indices = ([[0, 1]], [[2, 3]]) n_cons = len(indices[0]) freqs = np.arange(10, 25 + 1) @@ -1496,9 +1494,9 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): if method in ["cacoh", "mic"]: for indices_type in ["full", "ragged"]: if indices_type == "full": - indices = (np.array([[0, 1]]), np.array([[2, 3]])) + indices = ([[0, 1]], [[2, 3]]) else: - indices = (np.array([[0, 1]]), np.array([[2]])) + indices = ([[0, 1]], [[2]]) max_n_chans = 2 patterns_shape = [n_cons, max_n_chans] if faverage: @@ -1527,10 +1525,7 @@ def 
test_multivar_spectral_connectivity_time_shapes(method, average, faverage): assert not np.any(np.isnan(patterns[0, ..., :, :])) assert not np.any(np.isnan(patterns[0, ..., 0, :])) assert np.all(np.isnan(patterns[1, ..., 1, :])) # padded entry - assert np.all( - np.array(con.indices) - == np.array((np.array([[0, 1]]), np.array([[2, -1]]))) - ) + assert np.all(np.array(con.indices) == np.array(([[0, 1]], [[2, -1]]))) @pytest.mark.parametrize("method", ["cacoh", "mic", "mim", "gc", "gc_tr"]) @@ -1542,14 +1537,14 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): n_epochs = 8 n_times = 256 data = np.random.rand(n_epochs, n_signals, n_times) - indices = (np.array([[0, 1]]), np.array([[2, 3]])) + indices = ([[0, 1]], [[2, 3]]) freqs = np.arange(10, 25 + 1) # check bad indices without nested array caught with pytest.raises( TypeError, match="multivariate indices must contain array-likes" ): - non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + non_nested_indices = ([0, 1], [2, 3]) spectral_connectivity_time( data, freqs, @@ -1563,7 +1558,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): with pytest.raises( ValueError, match="multivariate indices cannot contain repeated" ): - repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) + repeated_indices = ([[0, 1, 1]], [[2, 2, 3]]) spectral_connectivity_time( data, freqs, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq ) @@ -1576,7 +1571,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ) # check bad rank args caught - too_low_rank = (np.array([0]), np.array([0])) + too_low_rank = ([0], [0]) with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_time( data, @@ -1587,7 +1582,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): mode=mode, rank=too_low_rank, ) - too_high_rank = (np.array([3]), np.array([3])) + too_high_rank = ([3], [3]) with 
pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_time( data, @@ -1609,7 +1604,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): mode=mode, rank=too_few_rank, ) - too_much_rank = (np.array([2, 2]), np.array([2, 2])) + too_much_rank = ([2, 2], [2, 2]) with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_time( data, @@ -1639,7 +1634,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ) # check intersecting indices caught - bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) + bad_indices = ([[0, 1]], [[0, 2]]) with pytest.raises( ValueError, match="seed and target indices must not intersect" ): @@ -1690,8 +1685,8 @@ def test_multivar_save_load(tmp_path): epochs = EpochsArray(data, info, tmin=tmin) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") - non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]])) - ragged_indices = (np.array([[0, 1]]), np.array([[2]])) + non_ragged_indices = ([[0, 1]], [[2, 3]]) + ragged_indices = ([[0, 1]], [[2]]) for indices in [non_ragged_indices, ragged_indices]: con = spectral_connectivity_epochs( epochs, @@ -1717,7 +1712,7 @@ def test_multivar_save_load(tmp_path): @pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv"]) -@pytest.mark.parametrize("indices", [None, (np.array([0, 1]), np.array([2, 3]))]) +@pytest.mark.parametrize("indices", [None, ([0, 1], [2, 3])]) def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. 
@@ -1758,7 +1753,7 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): @pytest.mark.parametrize("method", ["cacoh", "mic", "mim", "gc", "gc_tr"]) -@pytest.mark.parametrize("indices", [None, (np.array([[0, 1]]), np.array([[2, 3]]))]) +@pytest.mark.parametrize("indices", [None, ([[0, 1]], [[2, 3]])]) def test_multivar_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. From ec25d7d676d31cdc50cd0ba5aaceca8e2ce12608 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 11 Dec 2023 15:45:23 +0100 Subject: [PATCH 16/59] Switched to complex CaCoh --- .../spectral/epochs_multivariate.py | 13 ++++++++++--- .../example_multivariate_matlab_results.pkl | Bin 4147 -> 4916 bytes mne_connectivity/spectral/time.py | 7 ++++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index b138c425..6b76c42b 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -104,6 +104,7 @@ class _EpochMeanMultivariateConEstBase(_AbstractConEstBase): n_steps = None patterns = None + con_scores_dtype = np.float64 def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): self.n_signals = n_signals @@ -115,10 +116,14 @@ def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): # include time dimension, even when unused for indexing flexibility if n_times == 0: self.csd_shape = (n_signals**2, n_freqs) - self.con_scores = np.zeros((n_cons, n_freqs, 1)) + self.con_scores = np.zeros( + (n_cons, n_freqs, 1), dtype=self.con_scores_dtype + ) else: self.csd_shape = (n_signals**2, n_freqs, n_times) - self.con_scores = np.zeros((n_cons, n_freqs, n_times)) + self.con_scores = np.zeros( + (n_cons, n_freqs, n_times), dtype=self.con_scores_dtype + ) # allocate space for accumulation of CSD self._acc = 
np.zeros(self.csd_shape, dtype=np.complex128) @@ -469,6 +474,7 @@ class _CaCohEst(_MultivariateCohEstBase): """ name = "CaCoh" + con_scores_dtype = np.complex128 # CaCoh is complex-valued def _compute_con_daughter( self, seed_idcs, target_idcs, C, C_bar, U_bar_aa, U_bar_bb, con_i @@ -496,7 +502,8 @@ def _compute_con_daughter( C_bar_ab, T_aa, T_bb, max_coh, max_phis ) - self.con_scores[con_i] = max_coh.T + # Store connectivity score as complex value + self.con_scores[con_i] = (max_coh * np.exp(-1j * max_phis)).T self._compute_patterns( max_phis, diff --git a/mne_connectivity/spectral/tests/data/example_multivariate_matlab_results.pkl b/mne_connectivity/spectral/tests/data/example_multivariate_matlab_results.pkl index dd1e2bcb4e9056ace9fa5325afcc582694a97096..85bc2f84797d4158093723ab766c51b57b293505 100644 GIT binary patch delta 1582 zcmXBVjXM;00tfISnJ{*44P|r1a~368w2CQTq21b8M-o~_UP_o^qrAlA+-^3?dO0u4 z$otF5+Zcu!({E;&W^4^jl-w?DXt|VVL(1BFp1sdM@o6{anQ2_n_}31-HjOtr?2YUG zG`G3-vbsYjju$*Hbq(mod76R!fqWi9Ro5$1R3gMh47CkXJE4??b6-vNf}K5Va?6+x z?ml9nX(@ul2$No%YJ8MiZe97h8z=044NKZANA+0)zEcw=RIxXSKjkOM<)pEssYNni13yn(U^tE^#N$jzNw755U zFl(r!cWoo&i-&8EXt7O5buZcT+jtN8o7+#Rjv|7^D>X0UdK-zYeW|0VTmXTo=HKIe zJ4o(U|0A;2Mv&mNElx`U;>kE4KRr{;M%d+~iKu=d`7br zSlR(r@8f^Jpb}(LJrAwlSd3X+|3&>`2|6zFZ@rD?;U`ASuaO>n_&kd~aJr}wRjb9U z*pW7@xPF_kGR4CMwV&;mg(Yy6MI@IeiSbPVT`i4ynb66%1~V*nBZBQ_x;pg;b`~8A zv0LAXxQxBgFD+&85_p#fP782HcAW$eED!Mu=^aI~~;P{jz`#{6W@8wt_um=O0Fi4ieYP{})8N0Mrv zGOgI9V!Td%UT@x4OO$r|&b9m@!l~dQw};6b-1hzYUD%FR%&RY3#XhQozq5)p+$cf9 z9{H$7LkW5woO3c+6~nu5|LhWJB@!mf%l=X~8sBGo{g&E9!sCmA=L~Yda;_h`CGRF_ zMtA+;j)g`1p?NhtD5DCI;6VHdeN(bY__3z&r zD&X&#tah{_2&U2%o=>72Ntu}*M}_4?kW_s54x=AbMQMNA;0>~ROU+R1n#84d>apR`hAF431R% z2&LCt%ssgrLuUDjJI9)_)V^~n-@PBS3$YCwoVetS==C>cp+baYS#@&Oq~yCb_Ghuw zl^l!}JMxVLGGghp*Oq9C(C~+B)etd${6}?<+Yu=1f1dFvNqX)uWOj zSCWkHWBdYFi{OwQ-)%oDBmITbcXV!Fg=*4$S!Y5?x<7lNsbUg*lH``!k@f#FSWQ!di+{uCS04attDzA z%TPfcgfTwV*RDWpJ6BJVEhi^Ub1c1FQlYkWa+H#A7iMcL_S@b@d^~pGbK-H2occvE 
zY&BH{i`c=BWExbBF5}R>!lM6DfWivmrb%eFcOdu@i0b&tCe)&a}5DG zII5x)y0)W>?dB5wQbtTOPgy(t(Twm~Q^N>zDWT;Z-4?W%gTPD6F$1<@GO;nzX05Xb z`oHuY_o1+5#NaE7YsJrMF(paL=MDCf2FH({tz^c)Z{FY{Wn2Q~4=#_7>2#ri7hs;w z>VnpeSx%x#jNOZeDOa9!kmz^r+hyiF2$Idak|3ocXL)4nB8PO@JROd(66609tXr2KLCI1NwXx6Br0x12 D>IxS< delta 807 zcmV+?1K9kuCbJ*~fCZH(5U~aH41WUv000000H%ZAPgUE%KmN-4XLudLKSs2|WfZ=~ zKl`H8puC*IKg_!?MuzLVKWwYw)QB0!Kaak^1urVXKYb94ztz^oKXcvYjWnvlKfq3g zF)0YVKjx+wFK>juKPy2VuOT$RKQznS$}v*GKa89-Gp!!MKNB-YVi`}qKYw9&{trg5 zy+1w!_N}T)v_GVXwo}0M!atLrhE%fO!9R8SkTJddzCSEha)9TJwm(c{0vZ?^4`loO~>S1Jpjo+I+9u4rR2&#ac(ti zq-w`MYXVf1#Mi|?SIqN%+$P6AT+2wv>GsDzhy%R@^^4CxDB7(v=4TKXRNR#P?ynKY9EMfk$JyKO%Vc8KsBCKYiBvY**~Y zKiYNEP(+-^KMC7QX Date: Mon, 11 Dec 2023 16:07:14 +0100 Subject: [PATCH 17/59] [BUG] Fix failing tests h5py version mismatch --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 07c9302b..3108c7a5 100644 --- a/environment.yml +++ b/environment.yml @@ -23,3 +23,4 @@ dependencies: - pyqt!=5.15.3,!=5.15.4 - mne>=1.3 - h5netcdf + - h5py From 488c0d8cb01a5647e14397c906f8cf2ec53a0655 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 11 Dec 2023 16:10:04 +0100 Subject: [PATCH 18/59] Revert "[BUG] Fix failing tests h5py version mismatch" This reverts commit 3d281ff815880101d532b63ed7ff4bb605425080. 
--- environment.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/environment.yml b/environment.yml index 3108c7a5..07c9302b 100644 --- a/environment.yml +++ b/environment.yml @@ -23,4 +23,3 @@ dependencies: - pyqt!=5.15.3,!=5.15.4 - mne>=1.3 - h5netcdf - - h5py From f9ff2ac7f1baea1c71d989e909fa5efae3df6b04 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 11 Dec 2023 20:19:19 +0100 Subject: [PATCH 19/59] Added author --- mne_connectivity/spectral/epochs_multivariate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 6b76c42b..e9a1a351 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -5,6 +5,7 @@ # Tien D. Nguyen # Richard M. Köhler # Mohammad Orabe +# Mina Jamshidi Idaji # # License: BSD (3-clause) From 1cb52ad48de64d06b7c11df93b4e18892e9efbca Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 15 Dec 2023 13:03:10 +0100 Subject: [PATCH 20/59] Revert "[MAINT] Ignore black formatting in git blame" This reverts commit 788dddf03090a310cfb7fc10e8729ec3f302954d. 
--- .git-blame-ignore-revs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 30258f97..0dbf97cb 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,2 +1 @@ -0c216d35127f48792caabdc3cca170874a443eee # black, isort, ruff -e93fb47b3c5bb20164d5bb347b40692a52cfe430 # black \ No newline at end of file +0c216d35127f48792caabdc3cca170874a443eee # black, isort, ruff \ No newline at end of file From cc81cc47d2f2659a09e1719fa5ed61f801ccc7fa Mon Sep 17 00:00:00 2001 From: Mohammad Date: Wed, 31 Jan 2024 14:33:52 +0100 Subject: [PATCH 21/59] Update changelog file --- doc/whats_new.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 0199231e..9a4295c7 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -24,7 +24,7 @@ Version 0.7 (in dev) Enhancements ~~~~~~~~~~~~ -- +- Add support for a new multivariate connectivity method (canonical coherence; ``CaCoh``) in :func:`mne_connectivity.spectral_connectivity_epochs` and :func:`mne_connectivity.spectral_connectivity_time` by `Thomas Binns`_ and `Mohammad Orabe`_ (:pr:`163`). 
Bug ~~~ @@ -38,6 +38,8 @@ API Authors ~~~~~~~ +* `Thomas Binns`_ +* `Mohammad Orabe`_ :doc:`Find out what was new in previous releases ` From 11f7775e3357ad52d47aaf21b5b27ff636e7392e Mon Sep 17 00:00:00 2001 From: Mohammad Date: Wed, 31 Jan 2024 14:37:41 +0100 Subject: [PATCH 22/59] Add new authors to CITATION file --- CITATION.cff | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CITATION.cff b/CITATION.cff index 96be26d1..996e7fa1 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -20,6 +20,13 @@ authors: - given-names: "Alexandre" family-names: "Gramfort" orcid: https://orcid.org/0000-0001-9791-4404 + - given-names: "Thomas Samuel" + family-names: "Binns" + orcid: https://orcid.org/0000-0003-0657-0891 + - given-names: "Mohammad" + family-names: "Orabe" + orcid: https://orcid.org/0009-0004-7177-799X + title: "mne-connectivity" version: 0.2.0 From 513e15a179f6e7c338911bfa712566772763d81c Mon Sep 17 00:00:00 2001 From: Mohammad Date: Wed, 31 Jan 2024 14:39:25 +0100 Subject: [PATCH 23/59] Update authors file --- doc/authors.inc | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/authors.inc b/doc/authors.inc index 21a3fad4..14f36b11 100644 --- a/doc/authors.inc +++ b/doc/authors.inc @@ -14,3 +14,4 @@ .. _Thomas Binns: https://github.com/tsbinns .. _Tien Nguyen: https://github.com/nguyen-td .. _Richard Köhler: https://github.com/richardkoehler +.. _Mohammad Orabe: https://github.com/orabe From 2a2a8ed498550af424636545d883d602a7452528 Mon Sep 17 00:00:00 2001 From: Mohammad Date: Wed, 31 Jan 2024 14:42:07 +0100 Subject: [PATCH 24/59] add a new author --- doc/authors.inc | 1 + doc/whats_new.rst | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/authors.inc b/doc/authors.inc index 14f36b11..61ead35b 100644 --- a/doc/authors.inc +++ b/doc/authors.inc @@ -15,3 +15,4 @@ .. _Tien Nguyen: https://github.com/nguyen-td .. _Richard Köhler: https://github.com/richardkoehler .. _Mohammad Orabe: https://github.com/orabe +.. 
_Mina Jamshidi: https://github.com/minajamshidi diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 9a4295c7..90365443 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -24,7 +24,7 @@ Version 0.7 (in dev) Enhancements ~~~~~~~~~~~~ -- Add support for a new multivariate connectivity method (canonical coherence; ``CaCoh``) in :func:`mne_connectivity.spectral_connectivity_epochs` and :func:`mne_connectivity.spectral_connectivity_time` by `Thomas Binns`_ and `Mohammad Orabe`_ (:pr:`163`). +- Add support for a new multivariate connectivity method (canonical coherence; ``CaCoh``) in :func:`mne_connectivity.spectral_connectivity_epochs` and :func:`mne_connectivity.spectral_connectivity_time` by `Thomas Binns`_ and `Mohammad Orabe`_ and `Mina Jamshidi`_ (:pr:`163`). Bug ~~~ @@ -40,6 +40,7 @@ Authors ~~~~~~~ * `Thomas Binns`_ * `Mohammad Orabe`_ +* `Mina Jamshidi`_ :doc:`Find out what was new in previous releases ` From 257a324b43f6e4052a4e9fd54ef5a336414941c1 Mon Sep 17 00:00:00 2001 From: Mohammad Date: Wed, 31 Jan 2024 15:18:02 +0100 Subject: [PATCH 25/59] Add docstring for _invsqrtm function --- .../spectral/epochs_multivariate.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index e9a1a351..f8e35284 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -321,6 +321,23 @@ def reshape_results(self): def _invsqrtm(C, T, n_seeds): """Compute inverse sqrt of CSD over times (used for CaCoh, MIC, & MIM). + Parameters + ---------- + C : np.ndarray, shape=(n_times, n_channels, n_channels) + CSD for a single frequency and all times (n_times=1 if the mode is not + time-frequency resolved, e.g. multitaper). + T : np.ndarray, shape=(n_times, n_channels, n_channels) + Empty array to store the inverse square root of the CSD in. 
+ n_seeds : int + Number of seed channels for the connection. + + Returns + ------- + T : np.ndarray, shape=(n_times, n_channels, n_channels) + Inverse square root of the CSD. Name comes from Ewald et al. (2012). + + Notes + ----- Kept as a standalone function to allow for parallelisation over CSD frequencies. From 45458e94b54f16e18914ca8ba1a7e692b6989220 Mon Sep 17 00:00:00 2001 From: Mohammad Date: Wed, 31 Jan 2024 15:26:48 +0100 Subject: [PATCH 26/59] Update docstring for _MultivariateCohEstBase --- mne_connectivity/spectral/epochs_multivariate.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index f8e35284..b64415a7 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -180,8 +180,11 @@ class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): """Base estimator for multivariate coherency methods. See: - - Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 - - Vidaurre et al. (2019). NeuroImage. DOI: 10.1016/j.neuroimage.2019.116009 + - Imaginary part of coherency, i.e. multivariate imaginary part of coherency (MIC) + and multivariate interaction measure (MIM): Ewald et al. (2012). NeuroImage. DOI: + 10.1016/j.neuroimage.2011.11.084 + - Coherency/coherence, i.e. canonical coherence (CaCoh): Vidaurre et al. (2019). + NeuroImage. 
DOI: 10.1016/j.neuroimage.2019.116009 """ name: Optional[str] = None From f8862976b2129e4ce72a5bf040a62d0008cdec5e Mon Sep 17 00:00:00 2001 From: Mohammad Date: Fri, 2 Feb 2024 06:06:20 +0100 Subject: [PATCH 27/59] add draft for cacoh example --- examples/cacoh.py | 493 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 493 insertions(+) create mode 100644 examples/cacoh.py diff --git a/examples/cacoh.py b/examples/cacoh.py new file mode 100644 index 00000000..e3785625 --- /dev/null +++ b/examples/cacoh.py @@ -0,0 +1,493 @@ +""" +==================================================== +Compute multivariate measure of (absolute) coherence +==================================================== + +This example showcases the application of the Canonical Coherence (CaCoh) method, as detailed :footcite: `Vidaurre et al. (2019)`, for detecting neural synchronization in multivariate signal spaces. + +The method maximizes the absolute value of the coherence between two sets of multivariate spaces directly in the frequency domain. For each frequency bin two spatial filters are computed in order to maximize the coherence between the projected components. +""" + +# Authors: Mohammad Orabe +# Thomas S. Binns +# +# License: BSD (3-clause) + +# %% +import numpy as np +from matplotlib import pyplot as plt + +import mne +import mne_connectivity + +############################################################################### +# Background +# ---------- +# +# Multivariate connectivity methods have emerged as a sophisticated approach to +# understand the complex interplay between multiple brain regions simultaneously. These +# methods transcend the limitations of bivariate analyses, which focus on pairwise +# interactions, by offering a holistic view of brain network dynamics. 
However, +# challenges such as volume conduction and source mixing - where signals from distinct +# neural sources blend or appear falsely synchronized due to the conductive properties +# of brain tissue - complicate the interpretation of connectivity data. While other +# multivariate methods like `MIC` specifically exclude zero time-lag interactions to +# avoid volume conduction artifacts, assuming these to be non-physiological, the +# Canonical Coherence (CaCoh) method can capture and analyze interactions between +# signals with both zero and non-zero time-lag. +# +# This capability allows CaCoh to identify neural interactions that occur +# simultaneously, offering insights into connectivity that may be overlooked by other +# methods. The CaCoh method utilizes spatial filters to isolate true neural +# interactions. This approach not only simplifies the interpretation of complex neural +# interactions but also potentially enhances the signal-to-noise ratio, offering +# methodological advantages such as reduced bias from source mixing. +# +# CaCoh maximizes the absolute value of the coherence between the two multivariate +# spaces directly in the frequency domain, where for each frequency bin, two spatial +# filters are computed in order to maximize the coherence between the projected +# components. + + +############################################################################### +# Data Simulation +# --------------- +# +# The CaCoh method can be used to investigate the synchronization between two +# modalities.
For instance, it can be applied to optimize (in a non-invasive way) the +# cortico-muscular coherence between a central set of electroencephalographic (EEG) +# sensors and a peripheral set of electromyographic (EMG) electrodes, where each +# subspace is multivariate. CaCoh is capable of taking into account the fact that +# cortico-spinal interactions are multivariate in nature not only on the cortical level +# but also at the level of the spinal cord, where multiple afferent and efferent +# processes occur. +# +# CaCoh extends beyond analyzing cortico-muscular interactions. It is also applicable +# in various EEG/MEG/LFP research scenarios, such as studying the interactions between +# cortical (EEG/MEG) and subcortical activities, or in examining intracortical local +# field potentials (LFP). +# +# In this demo script, we will generate synthetic EEG signals. Let's define a function +# that enables the study of both zero-lag and non-zero-lag interactions by adjusting +# the parameter `connection_delay`. + + +# %% +def simulate_connectivity( + n_seeds: int, + n_targets: int, + freq_band: tuple[int, int], + n_epochs: int, + n_times: int, + sfreq: int, + snr: float, + connection_delay, + rng_seed: int | None = None, +) -> mne.Epochs: + """Simulates signals interacting in a given frequency band. + + Parameters + ---------- + n_seeds : int + Number of seed channels to simulate. + + n_targets : int + Number of target channels to simulate. + + freq_band : tuple of int, int + Frequency band where the connectivity should be simulated, where the first entry corresponds + to the lower frequency, and the second entry to the higher frequency. + + n_epochs : int + Number of epochs in the simulated data. + + n_times : int + Number of timepoints in each epoch of the simulated data. + + sfreq : int + Sampling frequency of the simulated data, in Hz. + + snr : float + Signal-to-noise ratio of the simulated data.
+ + connection_delay : + Number of timepoints for the delay of connectivity between the seeds and targets. If > 0, + the target data is a delayed form of the seed data by this many timepoints. + + rng_seed : int | None (default None) + Seed to use for the random number generator. If `None`, no seed is specified. + + Returns + ------- + epochs : mne.Epochs + The simulated data stored in an Epochs object. The channels are arranged according to seeds, + then targets. + """ + if rng_seed is not None: + np.random.seed(rng_seed) + + n_channels = n_seeds + n_targets + trans_bandwidth = 1 # Hz + + # simulate signal source at desired frequency band + signal = np.random.randn(1, n_epochs * n_times + connection_delay) + signal = mne.filter.filter_data( + data=signal, + sfreq=sfreq, + l_freq=freq_band[0], + h_freq=freq_band[1], + l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth, + fir_design="firwin2", + verbose=False, + ) + + # simulate noise for each channel + noise = np.random.randn(n_channels, n_epochs * n_times + connection_delay) + + # create data by projecting signal into noise + data = (signal * snr) + (noise * (1 - snr)) + + # shift target data by desired delay + if connection_delay > 0: + # shift target data + data[n_seeds:, connection_delay:] = data[n_seeds:, : n_epochs * n_times] + # remove extra time + data = data[:, : n_epochs * n_times] + + # reshape data into epochs + data = data.reshape(n_channels, n_epochs, n_times) + data = data.transpose((1, 0, 2)) # (epochs x channels x times) + + # store data in an MNE Epochs object + ch_names = [f"{ch_i}_{freq_band[0]}_{freq_band[1]}" for ch_i in range(n_channels)] + info = mne.create_info( + ch_names=ch_names, sfreq=sfreq, ch_types="eeg", verbose=False + ) + epochs = mne.EpochsArray(data=data, info=info, verbose=False) + + return epochs + + +# %% +def plot_absolute_coherency(conn_data, label): + """Plot the absolute value of coherency across frequencies""" + _, axis = plt.subplots() + axis.plot( + 
conn_data.freqs, np.abs(conn_data.get_data()[0]), linewidth=2, label=label + ) + axis.set_xlabel("Frequency (Hz)") + axis.set_ylabel("Absolute connectivity (A.U.)") + plt.title("CaCoh") + plt.legend(loc="upper right") + plt.show() + + +# %% +# Set parameters +n_epochs = 10 +n_times = 200 +sfreq = 100 +snr = 0.7 +freq_bands = { + "theta": [4.0, 8], + "alpha": [8.0, 12], + "beta": [12.0, 25], + "Gamma": [30.0, 45.0], +} + +n_seeds = 4 +n_targets = 3 +indices = ([np.arange(n_seeds)], [n_seeds + np.arange(n_targets)]) + +# %% We will generate synthetic EEG signals. +# First we will simulate a small dataset that consists of 4 synthetic EEG sensors +# designed as seed channels and 3 synthetic EEG sensors designed as target channels. +# Then we will consider two cases: one with zero and one with non-zero time-lag, to +# explore the CaCoh of each frequency bin. Every channel contains a shared band-pass +# filtered source signal mixed with independent noise; the targets can be delayed. + +# %% +# Case 1: Zero time-lag interactions. +# +# In our first scenario, we explore connectivity dynamics without any temporal +# separation between the seed and target channels, setting the connectivity delay to +# zero. This configuration will allow us to investigate instantaneous interactions, +# simulating conditions where neural signals are synchronized without time lag. +delay = 0 + +# Generate simulated data +con_data = simulate_connectivity( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_bands["beta"], + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, + connection_delay=delay, + rng_seed=42, +) + +# %% +# Compute the multivariate connectivity using the CaCoh method. +con = mne_connectivity.spectral_connectivity_epochs( + con_data, indices=indices, method="cacoh" +) + +# %% +# Plot the absolute coherence value for each frequency bin.
+plot_absolute_coherency(con, "Zero-lag interaction") + +# We observe a significant peak in the beta frequency band, indicating a high level of +# coherence. This suggests a strong synchronization between the seed and target +# channels within that frequency range. One might assume that such synchronization +# could be due to genuine neural connectivity. However, without phase lag, it's also +# possible that this result could stem from volume conduction or common reference +# artifacts, rather than direct physiological interactions. +# +# In the context of EEG analysis, volume conduction is a prevalent concern; yet, for +# example for LFP recordings, the spatial resolution is higher, and signals are less +# likely to be confounded by this phenomenon. Therefore, when interpreting such a peak +# in LFP data, one could be more confident that it reflects true neural interactions. + + +# %% +# Case 2: Non-zero time-lag interactions. +# +# For the exploration of non-zero time-lag interactions, we adjust the simulation to +# include a delay of 10 timepoints between the seed and target signals. This will model +# the temporal delays in neural communication. + +delay = 10 + +# Generate new simulated data +con_data = simulate_connectivity( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_bands["beta"], + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, + connection_delay=delay, + rng_seed=42, +) + +# %% +# Compute the multivariate connectivity for the new simulated data. +con = mne_connectivity.spectral_connectivity_epochs( + con_data, indices=indices, method="cacoh" +) + +# %% +# Plot the absolute coherence value for each frequency bin. +plot_absolute_coherency(con, "Non-zero-lag interaction") + +# We can see the coherence across frequencies with a notable peak also in the beta +# band, but the coherence values are overall a bit lower than in the zero-lag scenario. 
+# This illustrates the temporal delay introduced between seed and target signals, +# simulating a more realistic scenario where neuronal communications involve +# transmission delays (such as synaptic or axonal delays). + + +############################################################################### +# Theoretical description of Canonical Coherence (CaCoh) +# +# In mathematical terms, the Canonical Coherence (CaCoh) method aims to maximize the +# coherence between two signal spaces, :math:`A` and :math:`B`, of dimensions +# :math:`N_A` and :math:`N_B`, respectively. In a practical +# scenario, :math:`A` might represent signals from EMG sensors, and :math:`B` from EEG +# sensors. The primary goal of CaCoh is to find real-valued linear combinations of +# signals from these spaces that maximize coherence at a specific frequency. +# +# This maximization is formulated as: +# +# :math:`\textrm{CaCoh} = \lambda(\Phi) = \frac{\mathbf{a}^T \mathbf{D}(\Phi) \mathbf{b}}{\sqrt +# {\mathbf{a}^T \mathbf{a} \cdot \mathbf{b}^T \mathbf{b}}}` +# +# where :math:`\mathbf{D}(\Phi) = \mathbf{C}_{AA}^{-1/2} \mathbf{C}_{AB, \Phi}^R +# \mathbf{C}_{BB}^{-1/2}`. Here, :math:`\mathbf{C}_{AB, \Phi}^R` denotes the real +# part of the cross-spectrum, while :math:`\mathbf{C}_{AA}` and +# :math:`\mathbf{C}_{BB}` are the auto-spectral matrices for spaces :math:`A` and +# :math:`B`, respectively. +# +# The method inherently assumes instantaneous mixing of the signals, which justifies +# focusing on the real parts of the cross-spectral matrices. The complex Hermitian +# nature (where a square matrix is equal to its own conjugate transpose) of these +# matrices means that their imaginary components do not contribute to the maximization +# process and are thus typically set to zero.
+# +# The analytical resolution of CaCoh leads to an eigenvalue problem: +# +# :math:`\mathbf{D}(\Phi)^T \mathbf{D}(\Phi) \mathbf{b} = \lambda \mathbf{b}` +# :math:`\mathbf{D}(\Phi) \mathbf{D}(\Phi)^T \mathbf{a} = \lambda \mathbf{a}` +# +# where :math:`\mathbf{a}` and :math:`\mathbf{b}` are the eigenvectors derived +# from the respective spaces :math:`A` and :math:`B`. +# :math:`\lambda`, the maximal eigenvalue, represents the maximal CaCoh. The +# numerical estimation of the phase of coherence, where its absolute value is maximal, +# is achieved through a nonlinear search, emphasizing the method's robustness in +# identifying the most coherent signal combinations across different modalities. + + +############################################################################### +# Overfitting +# ----------- +# The concern regarding overfitting arises when the spatial filters :math:`\mathbf{a}` +# and :math:`\mathbf{b}`, designed to maximize Canonical Coherence (CaCoh), overly adapt +# to the specific dataset, compromising their applicability to new data. This is +# particularly relevant in high-dimensional datasets. To mitigate this, dimensionality +# reduction via Singular Value Decomposition (SVD) is applied to the real part of the +# cross-spectra in the spaces :math:`A` and :math:`B` before computing the spatial filters +# (Eqs. 14 & 15 of [1]). This process involves selecting singular vectors that preserve +# most of the data's information, ensuring that the derived filters are both effective +# and generalizable. +# +# The dimensionality of data can be controlled using the ``rank`` parameter, which by +# default assumes data is of full rank and does not reduce dimensionality. To +# accurately reflect the data's structure and avoid bias, it's important to choose a +# rank based on the expected number of significant components.
This selection helps +# standardize connectivity estimates across different recordings, even when the number +# of channels varies. Note that this does not refer to the number of seeds and targets +# within a connection being identical, rather to the number of seeds and targets across +# connections. +# +# In the following example, we will create two datasets with a larger number of seeds +# and targets. In the first dataset we apply the dimensionality reduction approach to +# only the first component in our rank subspace. We aim to compare the effects on +# connectivity patterns with the second dataset. +# +# The results indicate that essential connectivity patterns are preserved even after +# dimensionality reduction, implying that much of the connectivity information in the +# additional components may be redundant. This approach underscores the efficiency of +# focusing analysis on the most significant data dimensions. + +# %% +n_seeds = 15 +n_targets = 10 +indices = ([np.arange(n_seeds)], [n_seeds + np.arange(n_targets)]) +delay = 10 + +con_data = simulate_connectivity( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_bands["beta"], + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, + connection_delay=delay, + rng_seed=42, +) + +con_data_red = simulate_connectivity( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_bands["beta"], + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, + connection_delay=delay, + rng_seed=42, +) + +# %% +# Compute the multivariate connectivity using the CaCoh method. +con = mne_connectivity.spectral_connectivity_epochs( + con_data, indices=indices, method="cacoh" +) + +con_red = mne_connectivity.spectral_connectivity_epochs( + con_data, indices=indices, method="cacoh", rank=([1], [1]) +) + +# subtract mean of scores for comparison +con_meansub = con.get_data()[0] - con.get_data()[0].mean() +con_red_meansub = con_red.get_data()[0] - con_red.get_data()[0].mean() + +# no.
channels equal with and without projecting to rank subspace for patterns +assert ( + np.array(con.attrs["patterns"])[0, 0].shape[0] + == np.array(con_red.attrs["patterns"])[0, 0].shape[0] +) +assert ( + np.array(con.attrs["patterns"])[1, 0].shape[0] + == np.array(con_red.attrs["patterns"])[1, 0].shape[0] +) + +_, axis = plt.subplots() +axis.plot(con.freqs, con_meansub, linewidth=2, label="Standard cacoh") +axis.plot( + con_red.freqs, + con_red_meansub, + linewidth=2, + label="Rank subspace (1) cacoh", +) +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Absolute connectivity (A.U.)") +plt.title("CaCoh") +plt.legend(loc="upper right") +plt.show() + +# In the case that your data is not full rank and rank is left as None, an automatic +# rank computation is performed and an appropriate degree of dimensionality reduction +# will be enforced. The rank of the data is determined by computing the singular values +# of the data and finding those within a factor of :math:`1e-6` relative to the +# largest singular value. + +# Whilst unlikely, there may be scenarios in which this threshold may be too lenient. +# In these cases, you should inspect the singular values of your data to identify an +# appropriate degree of dimensionality reduction to perform, which you can then specify +# manually using the ``rank`` argument. The code below shows one possible approach for +# finding an appropriate rank of close-to-singular data with a more conservative +# threshold.
+ +# %% +# gets the singular values of the data +s = np.linalg.svd(con.get_data(), compute_uv=False) +# finds how many singular values are 'close' to the largest singular value +rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria +print(rank) + +############################################################################### +# Advantages and disadvantages +# +# In EEG data analysis, zero-lag interactions are typically viewed with suspicion +# because they often indicate volume conduction rather than genuine physiological +# interactions. Volume conduction is a phenomenon where electrical currents from active +# neurons spread passively through the brain tissue and skull to reach the scalp, where +# they are recorded. This can make spatially distinct but electrically active areas +# appear to be synchronously active, creating artificial coherence at zero lag. +# However, it is possible that some zero-lag interactions are real, especially if the +# neural sources are physically close to each other or if there is a common driver +# influencing multiple regions simultaneously. +# +# CaCoh, by design, does not specifically distinguish between zero-lag interactions +# that are physiological and those that are artifacts of volume conduction. Its main +# purpose is to identify patterns of maximal coherence across multiple channels or +# conditions. However, because it does not exclude zero-lag interactions, it might not +# inherently differentiate between true connectivity and volume conduction effects. +# +# Nevertheless, in the context of LFP signals, which typically represent local field +# potentials recorded from electrodes implanted in the brain, the concern for volume +# conduction is less pronounced compared to EEG because LFP signals are less influenced +# by the spread of electrical activity through the scalp and skull. 
In this domain, the +# CaCoh method still operates and can potentially capture true zero-lag interactions +# that are physiological in nature. The method could distinguish between true zero-lag +# interactions and those resulting from volume conduction if the spatial resolution is +# high enough to separate the sources of the signals, which is often the case with LFPs +# due to their proximity to the neural sources. +# +# On the other hand, the presence of a non-zero lag is often indicative of genuine +# physiological interactions, as it suggests a time course for signal transmission +# across neural pathways. This is especially pertinent in EEG/MEG and LFP analyses. +# CaCoh capture those interactions and help to understand the dynamics of these +# time-lagged connections. +# + +############################################################################### +# (CaCoh): Vidaurre et al. (2019). NeuroImage. DOI: 10.1016/j.neuroimage.2019.116009 +# (MIC) Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 +# TODO: cite: NolteEtAl2004 From 4652a44ba298ebac6c3fee4d04dfe4a627257ded Mon Sep 17 00:00:00 2001 From: Mohammad Date: Fri, 2 Feb 2024 16:22:34 +0100 Subject: [PATCH 28/59] (cacoh-example): consider spatial patterns, include references, link Eq. to paper --- examples/cacoh.py | 50 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/examples/cacoh.py b/examples/cacoh.py index e3785625..c6e6fe26 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -3,7 +3,7 @@ Compute multivariate measure of (absolute) coherence ==================================================== -This example showcases the application of the Canonical Coherence (CaCoh) method, as detailed :footcite: `Vidaurre et al. (2019)`, for detecting neural synchronization in multivariate signal spaces. 
+This example showcases the application of the Canonical Coherence (CaCoh) method, as detailed in :footcite:`VidaurreEtAl2019`, for detecting neural synchronization in multivariate signal spaces. The method maximizes the absolute value of the coherence between two sets of multivariate spaces directly in the frequency domain. For each frequency bin two spatial filters are computed in order to maximize the coherence between the projected components. """ @@ -31,10 +31,10 @@ # challenges such as volume conduction and source mixing—where signals from distinct # neural sources blend or appear falsely synchronized due to the conductive properties # of brain tissue, complicate the interpretation of connectivity data. While other -# multivariate methods like `MIC` specifically exclude zero time-lag interactions to -# avoid volume conduction artifacts, assuming these to be non-physiological, the -# Canonical Coherence (Cacoh) method can capture and analyze interactions between -# signals with both zero and non-zero time-lag. +# multivariate methods like `MIC <https://mne.tools/mne-connectivity/stable/auto_examples/mic_mim.html>`_ (:footcite:`EwaldEtAl2012`) specifically exclude +# zero time-lag interactions to avoid volume conduction artifacts, assuming these to be +# non-physiological, the Canonical Coherence (Cacoh) method can capture and analyze +# interactions between signals with both zero and non-zero time-lag. # # This capability allows Cacoh to identify neural interactions that occur # simultaneously, offering insights into connectivity that may be overlooked by other @@ -301,15 +301,14 @@ def plot_absolute_coherency(conn_data, label): # sensors. The primary goal of CaCoh is to find real-valued linear combinations of # signals from these spaces that maximize coherence at a specific frequency. # -# This maximization is formulated as: +# This maximization is formulated as (Eq.
8 of :footcite:`VidaurreEtAl2019`): # # :math:`\[ CaCoh = \lambda(\Phi)=\frac{\mathbf{a}^T \mathbf{D}(\Phi) \mathbf{b}}{\sqrt # {\mathbf{a}^T \mathbf{a} \cdot \mathbf{b}^T \mathbf{b}}} \]` # -# where :math:`\(\mathbf{D}(\Phi) = \mathbf{C}_{AA}^{-1/2} \mathbf{C}_{AB, \Phi}^R -# \mathbf{C}_{BB}^{-1/2}\)`. Here, :math:`\(\mathbf{C}_{AB, \Phi}^R\)` denotes the real -# part of the cross-spectrum, while :math:`\(\mathbf{C}_{AA}\)` and :math:`\(\mathbf{C}_ -# {BB}\)` are the auto-spectral matrices for spaces :math:`\(\alpha\)` and +# where :math:`\(\mathbf{D}(\Phi) = \mathbf{C}_{AA}^{-1/2} \mathbf{C}_{AB, \Phi} +# ^R \mathbf{C}_{BB}^{-1/2}\)`. Here, :math:`\(\mathbf{C}_{AB, \Phi}^R\)` denotes the +# real part of the cross-spectrum, while :math:`\(\mathbf{C}_{AA}\)` and :math:`\(\mathbf{C}_{BB}\)` are the auto-spectral matrices for spaces :math:`\(\alpha\)` and # :math:`\(\beta\)`, respectively. # # The method inherently assumes instantaneous mixing of the signals, which justifies @@ -318,7 +317,8 @@ def plot_absolute_coherency(conn_data, label): # matrices means that their imaginary components do not contribute to the maximization # process and are thus typically set to zero. # -# The analytical resolution of CaCoh leads to an eigenvalue problem: +# The analytical resolution of CaCoh leads to an eigenvalue problem (Eq. 12 of +# :footcite:`VidaurreEtAl2019`): # # :math:`\[ \mathbf{D}(\Phi)^T \mathbf{D}(\Phi) \mathbf{b} = \lambda \mathbf{b} \]` # :math:`\[ \mathbf{D}(\Phi) \mathbf{D}(\Phi)^T \mathbf{a} = \lambda \mathbf{a} \]` @@ -329,7 +329,25 @@ def plot_absolute_coherency(conn_data, label): # numerical estimation of the phase of coherence, where its absolute value is maximal, # is achieved through a nonlinear search, emphasizing the method's robustness in # identifying the most coherent signal combinations across different modalities. - +# +# To provide insights into the locations of sources influencing connectivity, spatial +# patterns can be obtained through spatial filters.
To identify the topographies +# corresponding to the spatial filters :math:`\alpha` and :math:`\beta`, the filters +# are multiplied by their respective real part of the cross-spectral matrix, as follows +# (Eq. 14 of :footcite: `VidaurreEtAl2019`): + +# For :math:`\alpha`, calculate: :math:`t_{\boldsymbol{\alpha}} = \mathbf{C}_{A A}^R +# \boldsymbol{\alpha}` +# For :math:`\beta`, calculate: :math:`t_{\boldsymbol{\beta}} = \mathbf{C}_{B B}^R +# \boldsymbol{\beta}` + +# These topographies represent the patterns of the sources with maximum coherence. The +# time courses of CaCoh components directly indicate the activity of neuronal sources. +# The spatial patterns, stored under the connectivity class's 'attrs['patterns']', +# assign a frequency-specific value to each seed and target channel. For simulated +# data, our focus is on coherence analysis without visualizing spatial patterns. +# An example for the visualization for the spatial patterns can be similarly +# accomplished using a the [``MIC``](https://mne.tools/mne-connectivity/stable/auto_examples/mic_mim.html) method (:footcite:`EwaldEtAl2012`). ############################################################################### # Overfitting @@ -488,6 +506,8 @@ def plot_absolute_coherency(conn_data, label): # ############################################################################### -# (CaCoh): Vidaurre et al. (2019). NeuroImage. DOI: 10.1016/j.neuroimage.2019.116009 -# (MIC) Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 -# TODO: cite: NolteEtAl2004 +# References +# ---------- +# .. 
footbibliography:: + +# %% From ab3c64be57c062376b0ffd5ef15a12ffc5ced146 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 5 Feb 2024 00:28:21 +0100 Subject: [PATCH 29/59] Add reference --- doc/references.bib | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/doc/references.bib b/doc/references.bib index 14271fd5..63d270b2 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -82,6 +82,16 @@ @article{HaufeEtAl2013 year = {2013} } +@article{HaufeEtAl2014, + author={Haufe, Stefan and Meinecke, Frank and G{\"o}rgen, Kai and D{\"a}hne, Sven and Haynes, John-Dylan and Blankertz, Benjamin and Bie{\ss}mann, Felix}, + doi = {10.1016/j.neuroimage.2013.10.067}, + journal={NeuroImage}, + pages={96--110}, + title={On the interpretation of weight vectors of linear models in multivariate neuroimaging}, + volume={87}, + year={2014}, +} + @article{HippEtAl2012, author = {Hipp, Joerg F and Hawellek, David J and Corbetta, Maurizio and Siegel, Markus and Engel, Andreas K}, doi = {10.1038/nn.3101}, From ac3d3f6054e24923d142128cc10e4836109bda62 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 5 Feb 2024 00:28:52 +0100 Subject: [PATCH 30/59] Update whats_new log message --- doc/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 90365443..ada7d447 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -24,7 +24,7 @@ Version 0.7 (in dev) Enhancements ~~~~~~~~~~~~ -- Add support for a new multivariate connectivity method (canonical coherence; ``CaCoh``) in :func:`mne_connectivity.spectral_connectivity_epochs` and :func:`mne_connectivity.spectral_connectivity_time` by `Thomas Binns`_ and `Mohammad Orabe`_ and `Mina Jamshidi`_ (:pr:`163`). 
+- Add support for a new multivariate connectivity method (canonical coherence; ``cacoh``) in :func:`mne_connectivity.spectral_connectivity_epochs` and :func:`mne_connectivity.spectral_connectivity_time` by `Thomas Binns`_ and `Mohammad Orabe`_ and `Mina Jamshidi`_ (:pr:`163`). Bug ~~~ From 5aadfdbbdc50c1a4aadea888b213c51ff146c663 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 5 Feb 2024 00:29:08 +0100 Subject: [PATCH 31/59] Update version citation file --- CITATION.cff | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 996e7fa1..41ea1d98 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -29,6 +29,6 @@ authors: title: "mne-connectivity" -version: 0.2.0 -date-released: 2022-01-13 +version: 0.7.0 +date-released: 2024-XX-XX url: "https://github.com/mne-tools/mne-connectivity" From d792c3ac40c4f7eb5c39ff73713da332130894c6 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 5 Feb 2024 00:29:35 +0100 Subject: [PATCH 32/59] Add preliminary CaCoh vs. MIC example --- examples/cacoh_vs_mic.py | 302 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 302 insertions(+) create mode 100644 examples/cacoh_vs_mic.py diff --git a/examples/cacoh_vs_mic.py b/examples/cacoh_vs_mic.py new file mode 100644 index 00000000..1eb97c90 --- /dev/null +++ b/examples/cacoh_vs_mic.py @@ -0,0 +1,302 @@ +""" +===================================== +Comparison of coherency-based methods +===================================== + +This example demonstrates how canonical coherency (CaCoh) +:footcite:`VidaurreEtAl2019` - a multivariate method based on coherency - can +be used to compute connectivity between whole sets of sensors, alongside +spatial patterns of the connectivity. +""" + +# Authors: Thomas S. 
Binns +# Mohammad Orabe +# License: BSD (3-clause) + +# %% +import numpy as np +from matplotlib import pyplot as plt + +import mne +from mne_connectivity import seed_target_indices, spectral_connectivity_epochs + +# %% + + +def simulate_connectivity( + n_seeds: int, + n_targets: int, + freq_band: tuple[int, int], + n_epochs: int, + n_times: int, + sfreq: int, + snr: float, + connection_delay: int, + rng_seed: int | None = None, +) -> np.ndarray: + """Simulates signals interacting in a given frequency band. + + Parameters + ---------- + n_seeds : int + Number of seed channels to simulate. + + n_targets : int + Number of target channels to simulate. + + freq_band : tuple of int, int + Frequency band where the connectivity should be simulated, where the + first entry corresponds to the lower frequency, and the second entry to + the higher frequency. + + n_epochs : int + Number of epochs in the simulated data. + + n_times : int + Number of timepoints each epoch of the simulated data. + + sfreq : int + Sampling frequency of the simulated data, in Hz. + + snr : float + Signal-to-noise ratio of the simulated data. + + connection_delay : + Number of timepoints for the delay of connectivity between the seeds + and targets. If > 0, the target data is a delayed form of the seed data + by this many timepoints. + + rng_seed : int | None (default None) + Seed to use for the random number generator. If `None`, no seed is + specified. + + Returns + ------- + data : numpy.ndarray + The simulated data stored in an array. The channels are arranged + according to seeds, then targets. 
+ """ + if rng_seed is not None: + np.random.seed(rng_seed) + + n_channels = n_seeds + n_targets + trans_bandwidth = 1 # Hz + + # simulate signal source at desired frequency band + signal = np.random.randn(1, n_epochs * n_times + connection_delay) + signal = mne.filter.filter_data( + data=signal, + sfreq=sfreq, + l_freq=freq_band[0], + h_freq=freq_band[1], + l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth, + fir_design="firwin2", + verbose=False, + ) + + # simulate noise for each channel + noise = np.random.randn(n_channels, n_epochs * n_times + connection_delay) + + # create data by projecting signal into noise + data = (signal * snr) + (noise * (1 - snr)) + + # shift target data by desired delay + if connection_delay > 0: + # shift target data + data[n_seeds:, connection_delay:] = data[n_seeds:, : n_epochs * n_times] + # remove extra time + data = data[:, : n_epochs * n_times] + + # reshape data into epochs + data = data.reshape(n_channels, n_epochs, n_times) + data = data.transpose((1, 0, 2)) # (epochs x channels x times) + + return data + + +# %% + +# Define simulation parameters +n_seeds = 3 +n_targets = 3 +n_channels = n_seeds + n_targets +n_epochs = 10 +n_times = 200 # samples +sfreq = 100 # Hz +snr = 0.7 +rng_seed = 44 + +# Generate simulated data +data_delay = simulate_connectivity( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=(10, 12), # 10-12 Hz interaction + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, + connection_delay=2, # samples + rng_seed=42, +) + +data_no_delay = simulate_connectivity( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=(23, 25), # 23-25 Hz interaction + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, + connection_delay=0, # samples + rng_seed=44, +) + +# Combine data into a single array +data = np.concatenate((data_delay, data_no_delay), axis=1) + +# %% + +# Generate connectivity indices +seeds = np.concatenate( + (np.arange(n_seeds), np.arange(n_channels, 
n_seeds + n_channels)) +) +targets = np.concatenate( + (np.arange(n_seeds, n_channels), np.arange(n_channels + n_seeds, n_channels * 2)) +) + +bivar_indices = (seeds, targets) +multivar_indices = ([seeds], [targets]) + +# Compute CaCoh & MIC +(cacoh, mic) = spectral_connectivity_epochs( + data, + method=["cacoh", "mic"], + indices=multivar_indices, + sfreq=sfreq, + fmin=3, + fmax=35, +) + +# %% + +fig, axis = plt.subplots(1, 1) +axis.plot(cacoh.freqs, np.abs(cacoh.get_data()[0]), linewidth=2, label="CaCoh") +axis.plot( + mic.freqs, np.abs(mic.get_data()[0]), linewidth=2, label="MIC", linestyle="--" +) +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Connectivity (A.U.)") +axis.annotate("Non-zero\ntime lag\ninteraction", xy=(13, 0.85)) +axis.annotate("Zero\ntime lag\ninteraction", xy=(27, 0.85)) +axis.legend(loc="upper left") +fig.suptitle("CaCoh vs. MIC\nNon-zero & zero time lags") + +# %% + +# Compute Coh & ImCoh +(coh, imcoh) = spectral_connectivity_epochs( + data, + method=["coh", "imcoh"], + indices=bivar_indices, + sfreq=sfreq, + fmin=3, + fmax=35, +) + +coh_mean = np.mean(coh.get_data(), axis=0) +imcoh_mean = np.mean(np.abs(imcoh.get_data()), axis=0) + +coh_mean_subbed = coh_mean - np.mean(coh_mean) +imcoh_mean_subbed = imcoh_mean - np.mean(imcoh_mean) + +fig, axis = plt.subplots(1, 1) +axis.plot(coh.freqs, coh_mean_subbed, linewidth=2, label="Coh") +axis.plot(imcoh.freqs, imcoh_mean_subbed, linewidth=2, label="ImCoh", linestyle="--") +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Mean-corrected connectivity (A.U.)") +axis.annotate("Non-zero\ntime lag\ninteraction", xy=(13, 0.25)) +axis.annotate("Zero\ntime lag\ninteraction", xy=(25, 0.25)) +axis.legend(loc="upper left") +fig.suptitle("Coh vs. 
ImCoh\nNon-zero & zero time lags") + +# %% + +# Generate simulated data +data_10_12 = simulate_connectivity( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=(10, 12), # 10-12 Hz interaction + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, + connection_delay=1, # samples + rng_seed=42, +) + +data_23_25 = simulate_connectivity( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=(23, 25), # 23-25 Hz interaction + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + snr=snr, + connection_delay=1, # samples + rng_seed=44, +) + +# Combine data into a single array +data = np.concatenate((data_10_12, data_23_25), axis=1) + +# Compute CaCoh & MIC +(cacoh, mic) = spectral_connectivity_epochs( + data, + method=["cacoh", "mic"], + indices=multivar_indices, + sfreq=sfreq, + fmin=3, + fmax=35, +) + +fig, axis = plt.subplots(1, 1) +axis.plot(cacoh.freqs, np.abs(cacoh.get_data()[0]), linewidth=2, label="CaCoh") +axis.plot( + mic.freqs, np.abs(mic.get_data()[0]), linewidth=2, label="MIC", linestyle="--" +) +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Connectivity (A.U.)") +axis.annotate("45°\ninteraction", xy=(12.5, 0.9)) +axis.annotate("90°\ninteraction", xy=(26.5, 0.9)) +axis.legend(loc="upper left") +fig.suptitle("CaCoh vs.
MIC\n45° & 90° interactions") + +# %% + +# Compute Coh & ImCoh +(coh, imcoh) = spectral_connectivity_epochs( + data, + method=["coh", "imcoh"], + indices=bivar_indices, + sfreq=sfreq, + fmin=3, + fmax=35, +) + +coh_mean = np.mean(coh.get_data(), axis=0) +imcoh_mean = np.mean(np.abs(imcoh.get_data()), axis=0) +coh_mean_subbed = coh_mean - np.mean(coh_mean) +imcoh_mean_subbed = imcoh_mean - np.mean(imcoh_mean) + +fig, axis = plt.subplots(1, 1) +axis.plot(coh.freqs, coh_mean_subbed, linewidth=2, label="Coh") +axis.plot(imcoh.freqs, imcoh_mean_subbed, linewidth=2, label="ImCoh", linestyle="--") +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Mean-corrected connectivity (A.U.)") +axis.annotate("45°\ninteraction", xy=(12, 0.25)) +axis.annotate("90°\ninteraction", xy=(26.5, 0.25)) +axis.legend(loc="upper left") +fig.suptitle("Coh vs. ImCoh\n45° & 90° interactions") + +# %% From cdca613e299fe17383dc8d0eb8596de345f0f296 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 5 Feb 2024 00:30:57 +0100 Subject: [PATCH 33/59] Update MIC/MIM example --- examples/mic_mim.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 244e045b..464a46b3 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -52,7 +52,9 @@ # increases the signal-to-noise ratio and allows signals to be analysed in a # multivariate manner :footcite:`EwaldEtAl2012`. This approach leads to the # following methods: the maximised imaginary part of coherency (MIC); and the -# multivariate interaction measure (MIM). +# multivariate interaction measure (MIM). These methods are similar to the +# multivariate method based on coherency (CaCoh; see :doc:`cacoh`), which is +# also supported by MNE-Connectivity. # # We start by loading some example MEG data and dividing it into # two-second-long epochs. @@ -121,8 +123,9 @@ # eigendecomposition of information from the cross-spectral density (Eq. 
7 of # :footcite:`EwaldEtAl2012`): # -# :math:`MIC=\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}}{\parallel -# \boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta}\parallel}`, +# :math:`\textrm{MIC}=\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} +# {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} +# \parallel}`, # # where :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are the # spatial filters for the seeds and targets, respectively, and @@ -155,13 +158,13 @@ ############################################################################### # Furthermore, spatial patterns of connectivity can be constructed from the -# spatial filters to give a picture of the location of the sources involved in -# the connectivity. This information is stored under ``attrs['patterns']`` of -# the connectivity class, with one value per frequency for each channel in the -# seeds and targets. As with MIC, the absolute value of the patterns reflect -# the strength, however the sign differences can be used to visualise the -# orientation of the underlying dipole sources. The spatial patterns are -# **not** bound between :math:`[-1, 1]`. +# spatial filters to give a picture of the location of the channels involved in +# the connectivity :footcite:`HaufeEtAl2014`. This information is stored under +# ``attrs['patterns']`` of the connectivity class, with one value per frequency +# for each channel in the seeds and targets. As with MIC, the absolute value of +# the patterns reflect the strength, however the sign differences can be used +# to visualise the orientation of the underlying dipole sources. The spatial +# patterns are **not** bound between :math:`[-1, 1]`. # # Here, we average across the patterns in the 13-18 Hz range. Plotting the # patterns shows that the greatest connectivity between the left and right @@ -242,7 +245,7 @@ # component explicitly, and instead the desired result can be achieved from # :math:`E` alone (Eq. 
14 of :footcite:`EwaldEtAl2012`): # -# :math:`MIM=tr(\boldsymbol{EE}^T)`, +# :math:`\textrm{MIM}=tr(\boldsymbol{EE}^T)`, # # where again the frequency dependence is omitted. Unlike MIC, MIM is # positive-valued and can be > 1. Without normalisation, MIM can be @@ -371,6 +374,7 @@ # compare standard and rank subspace-projected MIM fig, axis = plt.subplots(1, 1) +axis.plot(mim.freqs, mim_meansub, linewidth=2, label="standard MIM") axis.plot(mim_red.freqs, mim_red_meansub, linewidth=2, label="rank subspace (25) MIM") axis.plot(mim.freqs, mim_meansub, linewidth=2, label="standard MIM") axis.set_xlabel("Frequency (Hz)") From 5cb9c449d9e3f3ec3b671d9d80b86b02a0e045cb Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 5 Feb 2024 00:31:32 +0100 Subject: [PATCH 34/59] Update docstrings --- mne_connectivity/spectral/epochs.py | 24 ++++++++++------- .../spectral/epochs_multivariate.py | 12 ++++----- mne_connectivity/spectral/time.py | 26 +++++++++++-------- 3 files changed, 35 insertions(+), 27 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 628b20cd..46195677 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -815,26 +815,27 @@ def spectral_connectivity_epochs( C = ---------------------- sqrt(E[Sxx] * E[Syy]) - 'cacoh' : Canonical Coherence (CaCoh) :footcite:`VidaurreEtAl2019` + 'cacoh' : Canonical Coherency (CaCoh) :footcite:`VidaurreEtAl2019` given by: - :math:`CaCoh=\Large{\frac{\mid\boldsymbol{a}^T\boldsymbol{D}(\Phi) - \boldsymbol{b}\mid}{\sqrt{\boldsymbol{a}^T\boldsymbol{a} + :math:`\textrm{CaCoh}=\Large{\frac{\boldsymbol{a}^T\boldsymbol{D} + (\Phi)\boldsymbol{b}}{\sqrt{\boldsymbol{a}^T\boldsymbol{a} \boldsymbol{b}^T\boldsymbol{b}}}}` where: :math:`\boldsymbol{D}(\Phi)` is the cross-spectral density between seeds and targets transformed for a given phase angle :math:`\Phi`; and :math:`\boldsymbol{a}` and :math:`\boldsymbol{b}` - are eigenvectors for the seeds 
and targets, such that :math:`\mid - \boldsymbol{a}^T\boldsymbol{D}(\Phi)\boldsymbol{b}\mid` maximises - coherence between the seeds and targets. + are eigenvectors for the seeds and targets, such that + :math:`\boldsymbol{a}^T\boldsymbol{D}(\Phi)\boldsymbol{b}` + maximises coherency between the seeds and targets. Taking the + absolute value of the results gives maximised coherence. 'mic' : Maximised Imaginary part of Coherency (MIC) :footcite:`EwaldEtAl2012` given by: - :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} - {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} - \parallel}}` + :math:`\textrm{MIC}=\Large{\frac{\boldsymbol{\alpha}^T + \boldsymbol{E \beta}}{\parallel\boldsymbol{\alpha}\parallel + \parallel\boldsymbol{\beta}\parallel}}` where: :math:`\boldsymbol{E}` is the imaginary part of the transformed cross-spectral density between seeds and targets; and @@ -846,7 +847,10 @@ def spectral_connectivity_epochs( 'mim' : Multivariate Interaction Measure (MIM) :footcite:`EwaldEtAl2012` given by: - :math:`MIM=tr(\boldsymbol{EE}^T)` + :math:`\textrm{MIM}=tr(\boldsymbol{EE}^T)` + + where :math:`\boldsymbol{E}` is the imaginary part of the + transformed cross-spectral density between seeds and targets. 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given by:: diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index b64415a7..d6d4f65b 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -180,11 +180,11 @@ class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): """Base estimator for multivariate coherency methods. See: - - Imaginary part of coherency, i.e. multivariate imaginary part of coherency (MIC) - and multivariate interaction measure (MIM): Ewald et al. (2012). NeuroImage. DOI: - 10.1016/j.neuroimage.2011.11.084 - - Coherency/coherence, i.e. 
canonical coherence (CaCoh): Vidaurre et al. (2019). - NeuroImage. DOI: 10.1016/j.neuroimage.2019.116009 + - Imaginary part of coherency, i.e. multivariate imaginary part of + coherency (MIC) and multivariate interaction measure (MIM): Ewald et al. + (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 + - Coherency/coherence, i.e. canonical coherency (CaCoh): Vidaurre et al. + (2019). NeuroImage. DOI: 10.1016/j.neuroimage.2019.116009 """ name: Optional[str] = None @@ -345,7 +345,7 @@ def _invsqrtm(C, T, n_seeds): frequencies. See Eq. 3 of Ewald et al. (2012). NeuroImage. DOI: - 10.1016/j.neuroimage.2011.11.084 + 10.1016/j.neuroimage.2011.11.084. """ for time_i in range(C.shape[0]): T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index a76f3ca7..8e4d0f66 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -73,7 +73,7 @@ def spectral_connectivity_time( ``['coh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'pli', 'wpli', 'gc', 'gc_tr']``. 
These are: * 'coh' : Coherence - * 'cacoh' : Canonical Coherence (CaCoh) + * 'cacoh' : Canonical Coherency (CaCoh) * 'mic' : Maximised Imaginary part of Coherency (MIC) * 'mim' : Multivariate Interaction Measure (MIM) * 'plv' : Phase-Locking Value (PLV) @@ -252,26 +252,27 @@ def spectral_connectivity_time( C = --------------------- sqrt(E[Sxx] * E[Syy]) - 'cacoh' : Canonical Coherence (CaCoh) :footcite:`VidaurreEtAl2019` + 'cacoh' : Canonical Coherency (CaCoh) :footcite:`VidaurreEtAl2019` given by: - :math:`CaCoh=\Large{\frac{\mid\boldsymbol{a}^T\boldsymbol{D}(\Phi) - \boldsymbol{b}\mid}{\sqrt{\boldsymbol{a}^T\boldsymbol{a} + :math:`\textrm{CaCoh}=\Large{\frac{\boldsymbol{a}^T\boldsymbol{D} + (\Phi)\boldsymbol{b}}{\sqrt{\boldsymbol{a}^T\boldsymbol{a} \boldsymbol{b}^T\boldsymbol{b}}}}` where: :math:`\boldsymbol{D}(\Phi)` is the cross-spectral density between seeds and targets transformed for a given phase angle :math:`\Phi`; and :math:`\boldsymbol{a}` and :math:`\boldsymbol{b}` - are eigenvectors for the seeds and targets, such that :math:`\mid - \boldsymbol{a}^T\boldsymbol{D}(\Phi)\boldsymbol{b}\mid` maximises - coherence between the seeds and targets. + are eigenvectors for the seeds and targets, such that + :math:`\boldsymbol{a}^T\boldsymbol{D}(\Phi)\boldsymbol{b}` + maximises coherency between the seeds and targets. Taking the + absolute value of the results gives maximised coherence. 
'mic' : Maximised Imaginary part of Coherency (MIC) :footcite:`EwaldEtAl2012` given by: - :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} - {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} - \parallel}}` + :math:`\textrm{MIC}=\Large{\frac{\boldsymbol{\alpha}^T + \boldsymbol{E \beta}}{\parallel\boldsymbol{\alpha}\parallel + \parallel\boldsymbol{\beta}\parallel}}` where: :math:`\boldsymbol{E}` is the imaginary part of the transformed cross-spectral density between seeds and targets; and @@ -283,7 +284,10 @@ def spectral_connectivity_time( 'mim' : Multivariate Interaction Measure (MIM) :footcite:`EwaldEtAl2012` given by: - :math:`MIM=tr(\boldsymbol{EE}^T)` + :math:`\textrm{MIM}=tr(\boldsymbol{EE}^T)` + + where :math:`\boldsymbol{E}` is the imaginary part of the + transformed cross-spectral density between seeds and targets. 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given by:: From 5cb370a3869b233f6c8617cfe60b45a2699afa4b Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 5 Feb 2024 00:35:06 +0100 Subject: [PATCH 35/59] Redraft CaCoh example --- examples/cacoh.py | 699 ++++++++++++++++++++++++---------------------- 1 file changed, 366 insertions(+), 333 deletions(-) diff --git a/examples/cacoh.py b/examples/cacoh.py index c6e6fe26..932d8a99 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -1,16 +1,16 @@ """ -==================================================== -Compute multivariate measure of (absolute) coherence -==================================================== - -This example showcases the application of the Canonical Coherence (CaCoh) method, as detailed :footcite: `VidaurreEtAl2019`, for detecting neural synchronization in multivariate signal spaces. - -The method maximizes the absolute value of the coherence between two sets of multivariate spaces directly in the frequency domain. 
For each frequency bin two spatial filters are computed in order to maximize the coherence between the projected components. +======================================== +Compute multivariate coherency/coherence +======================================== + +This example demonstrates how canonical coherency (CaCoh) +:footcite:`VidaurreEtAl2019` - a multivariate method based on coherency - can +be used to compute connectivity between whole sets of sensors, alongside +spatial patterns of the connectivity. """ # Authors: Mohammad Orabe # Thomas S. Binns -# # License: BSD (3-clause) # %% @@ -18,61 +18,55 @@ from matplotlib import pyplot as plt import mne -import mne_connectivity +from mne_connectivity import seed_target_indices, spectral_connectivity_epochs ############################################################################### # Background # ---------- # -# Multivariate connectivity methods have emerged as a sophisticated approach to -# understand the complex interplay between multiple brain regions simultaneously. These -# methods transcend the limitations of bivariate analyses, which focus on pairwise -# interactions, by offering a holistic view of brain network dynamics. However, -# challenges such as volume conduction and source mixing—where signals from distinct -# neural sources blend or appear falsely synchronized due to the conductive properties -# of brain tissue, complicate the interpretation of connectivity data. While other -# multivariate methods like [``MIC``](https://mne.tools/mne-connectivity/stable/auto_examples/mic_mim.html) by (:footcite:`EwaldEtAl2012`) specifically exclude -# zero time-lag interactions to avoid volume conduction artifacts, assuming these to be -# non-physiological, the Canonical Coherence (Cacoh) method can capture and analyze -# interactions between signals with both zero and non-zero time-lag. +# Multivariate forms of signal analysis allow you to simultaneously consider +# the activity of multiple signals. 
In the case of connectivity, the +# interaction between multiple sensors can be analysed at once, producing a +# single connectivity spectrum. This approach brings not only practical +# benefits (e.g. easier interpretability of results from the dimensionality +# reduction), but can also offer methodological improvements (e.g. enhanced +# signal-to-noise ratio). # -# This capability allows Cacoh to identify neural interactions that occur -# simultaneously, offering insights into connectivity that may be overlooked by other -# methods. The Cacoh method utilizes spatial filters to isolate true neural -# interactions. This approach not only simplifies the interpretation of complex neural -# interactions but also potentially enhances the signal-to-noise ratio, offering -# methodological advantages such as reduced bias from source mixing. +# A popular bivariate measure of connectivity is coherency/coherence, which +# looks at the correlation between two signals in the frequency domain. +# However, in cases where interactions between multiple signals are of +# interest, computing connectivity between all possible combinations of signals +# leads to a very large number of results which is difficult to interpret. A +# common approach is to average results across these connections, however this +# risks reducing the signal-to-noise ratio of results and burying interactions +# that are present between only a small number of channels. # -# CaCoh maximizes the absolute value of the coherence between the two multivariate -# spaces directly in the frequency domain, where for each frequency bin, two spatial -# filters are computed in order to maximize the coherence between the projected -# components. +# Canonical coherency (CaCoh) is a multivariate form of coherency that uses +# spatial filters to extract the relevant components of connectivity in a +# frequency-resolved manner :footcite:`VidaurreEtAl2019`. 
It is similar to +# multivariate methods based on the imaginary part of coherency (MIC & MIM; see +# :doc:`mic_mim`), which are also supported by MNE-Connectivity. ############################################################################### # Data Simulation # --------------- # -# The CaCoh method can be used to investigate the synchronization between two -# modalities. For instance it can be applied to optimize (in a non-invasive way) the -# cortico-muscular coherence between a central set of electroencephalographic (EEG) -# sensors and a peripheral set of electromyographic (EMG) electrodes, where each -# subspaces is multivariate CaCoh is capable of taking into account the fact that -# cortico-spinal interactions are multivariate in nature not only on the cortical level -# but also at the level of the spinal cord, where multiple afferent and efferent -# processes occur. +# To demonstrate the CaCoh method, we will use some simulated data consisting +# of an interaction between signals in a given frequency range. Here, we +# simulate two sets of interactions: # -# CaCoh extends beyond analyzing cortico-muscular interactions. It is also applicable -# in various EEG/MEG/LFP research scenarios, such as studying the interactions between -# cortical (EEG/MEG) and subcortical activities, or in examining intracortical local -# field potentials (LFP). +# - 5 seeds and 3 targets interacting in the 10-12 Hz frequency range. +# - 5 seeds and 3 targets interacting in the 23-25 Hz frequency range. # -# In this demo script, we will generates synthetic EEG signals. Let's define a function -# that enables the study of both zero-lag and non-zero-lag interactions by adjusting -# the parameter `connection_delay`. - +# We can consider the seeds and targets to be signals of different modalities, +# e.g. cortical EEG signals and subcortical local field potential signals, +# cortical EEG signals and muscular EMG signals, etc. We use the function +# below to simulate these signals.
# %% + + def simulate_connectivity( n_seeds: int, n_targets: int, @@ -81,9 +75,9 @@ def simulate_connectivity( n_times: int, sfreq: int, snr: float, - connection_delay, + connection_delay: int, rng_seed: int | None = None, -) -> mne.Epochs: +) -> np.ndarray: """Simulates signals interacting in a given frequency band. Parameters @@ -95,8 +89,9 @@ def simulate_connectivity( Number of target channels to simulate. freq_band : tuple of int, int - Frequency band where the connectivity should be simulated, where the first entry corresponds - to the lower frequency, and the second entry to the higher frequency. + Frequency band where the connectivity should be simulated, where the + first entry corresponds to the lower frequency, and the second entry to + the higher frequency. n_epochs : int Number of epochs in the simulated data. @@ -111,17 +106,19 @@ def simulate_connectivity( Signal-to-noise ratio of the simulated data. connection_delay : - Number of timepoints for the delay of connectivity between the seeds and targets. If > 0, - the target data is a delayed form of the seed data by this many timepoints. + Number of timepoints for the delay of connectivity between the seeds + and targets. If > 0, the target data is a delayed form of the seed data + by this many timepoints. rng_seed : int | None (default None) - Seed to use for the random number generator. If `None`, no seed is specified. + Seed to use for the random number generator. If `None`, no seed is + specified. Returns ------- - epochs : mne.Epochs - The simulated data stored in an Epochs object. The channels are arranged according to seeds, - then targets. + data : numpy.ndarray + The simulated data stored in an array. The channels are arranged + according to seeds, then targets. 
""" if rng_seed is not None: np.random.seed(rng_seed) @@ -159,351 +156,387 @@ def simulate_connectivity( data = data.reshape(n_channels, n_epochs, n_times) data = data.transpose((1, 0, 2)) # (epochs x channels x times) - # store data in an MNE Epochs object - ch_names = [f"{ch_i}_{freq_band[0]}_{freq_band[1]}" for ch_i in range(n_channels)] - info = mne.create_info( - ch_names=ch_names, sfreq=sfreq, ch_types="eeg", verbose=False - ) - epochs = mne.EpochsArray(data=data, info=info, verbose=False) + return data - return epochs +############################################################################### # %% -def plot_absolute_coherency(conn_data, label): - """Plot the absolute value of coherency across frequencies""" - _, axis = plt.subplots() - axis.plot( - conn_data.freqs, np.abs(conn_data.get_data()[0]), linewidth=2, label=label - ) - axis.set_xlabel("Frequency (Hz)") - axis.set_ylabel("Absolute connectivity (A.U.)") - plt.title("CaCoh") - plt.legend(loc="upper right") - plt.show() - -# %% -# Set parameters +# Define simulation parameters +n_seeds = 5 +n_targets = 3 +n_channels = n_seeds + n_targets n_epochs = 10 -n_times = 200 -sfreq = 100 +n_times = 200 # samples +sfreq = 100 # Hz snr = 0.7 -freq_bands = { - "theta": [4.0, 8], - "alpha": [8.0, 12], - "beta": [12.0, 25], - "Gamma": [30.0, 45.0], -} - -n_seeds = 4 -n_targets = 3 -indices = ([np.arange(n_seeds)], [n_seeds + np.arange(n_targets)]) - -# %%we will generates synthetic EEG signals -# First we will simulate a small dataset that consists of 3 synthetic EEG sensors -# designed as seed channels and 4 synthetic EEG sensors designed as target channels -# Then we will consider two cases; one with zero- and one with non zero time-lag to -# explore the CaCoh of each frequency bin. The seed data is initially generated as -# noise. The target data are a band-pass filtered version of the seed channels. - -# %% -# Case 1: Zero time-lag interactions. 
-# -# In our first scenario, we explore connectivity dynamics without any temporal -# separation between the seed and target channels, setting the connectivity delay to -# zero. This configuration will allow us to investigate instantaneous interactions, -# simulating conditions where neural signals are synchronized without time lag. -delay = 0 +connection_delay = 10 # samples +rng_seed = 44 # Generate simulated data -con_data = simulate_connectivity( +data_10_12 = simulate_connectivity( n_seeds=n_seeds, n_targets=n_targets, - freq_band=freq_bands["beta"], + freq_band=(10, 12), # 10-12 Hz interaction n_epochs=n_epochs, n_times=n_times, sfreq=sfreq, snr=snr, - connection_delay=delay, + connection_delay=connection_delay, rng_seed=42, ) -# %% -# Compute the multivariate connectivity using the CaCoh method. -con = mne_connectivity.spectral_connectivity_epochs( - con_data, indices=indices, method="cacoh" -) - -# %% -# Plot the absolute coherence value for each frequency bin. -plot_absolute_coherency(con, "Zero-lag interaction") - -# We observe a significant peak in the beta frequency band, indicating a high level of -# coherence. This suggests a strong synchronization between the seed and target -# channels within that frequency range. One might assume that such synchronization -# could be due to genuine neural connectivity. However, without phase lag, it's also -# possible that this result could stem from volume conduction or common reference -# artifacts, rather than direct physiological interactions. -# -# In the context of EEG analysis, volume conduction is a prevalent concern; yet, for -# example for LFP recordings, the spatial resolution is higher, and signals are less -# likely to be confounded by this phenomenon. Therefore, when interpreting such a peak -# in LFP data, one could be more confident that it reflects true neural interactions. - - -# %% -# Case 2: Non-zero time-lag interactions. 
-# -# For the exploration of non-zero time-lag interactions, we adjust the simulation to -# include a delay of 10 timepoints between the seed and target signals. This will model -# the temporal delays in neural communication. - -delay = 10 - -# Generate new simulated data -con_data = simulate_connectivity( +data_23_25 = simulate_connectivity( n_seeds=n_seeds, n_targets=n_targets, - freq_band=freq_bands["beta"], + freq_band=(23, 25), # 23-25 Hz interaction n_epochs=n_epochs, n_times=n_times, sfreq=sfreq, snr=snr, - connection_delay=delay, - rng_seed=42, + connection_delay=connection_delay, + rng_seed=44, ) +# Combine data into a single array +data = np.concatenate((data_10_12, data_23_25), axis=1) + +############################################################################### +# Computing CaCoh +# --------------- +# +# Having simulated the signals, we can create the indices for computing +# connectivity between all seeds and all targets in a single multivariate +# connection (see :doc:`handling_ragged_arrays` for more information), after +# which we compute connectivity. +# +# For CaCoh, a set of spatial filters are found that will maximise the +# estimated connectivity between the seed and target signals. These maximising +# filters correspond to the eigenvectors with the largest eigenvalue, derived +# from an eigendecomposition of information from the cross-spectral density +# (Eq. 8 of :footcite:`VidaurreEtAl2019`): +# +# :math:`\textrm{CaCoh}=\Large{\frac{\boldsymbol{a}^T\boldsymbol{D}(\Phi) +# \boldsymbol{b}}{\sqrt{\boldsymbol{a}^T\boldsymbol{a}\boldsymbol{b}^T +# \boldsymbol{b}}}}` +# +# where: :math:`\boldsymbol{D}(\Phi)` is the cross-spectral density between +# seeds and targets transformed for a given phase angle :math:`\Phi`; and +# :math:`\boldsymbol{a}` and :math:`\boldsymbol{b}` are eigenvectors for the +# seeds and targets, such that :math:`\boldsymbol{a}^T\boldsymbol{D}(\Phi) +# \boldsymbol{b}` maximises coherency between the seeds and targets. 
All +# elements are frequency-dependent, however this is omitted for readability. +# +# CaCoh is complex-valued in the range :math:`[-1, 1]` where the sign reflects +# the phase angle of the interaction (akin to coherency). Taking the absolute +# value is akin to taking the coherence, which is the magnitude of the +# interaction regardless of phase angle. + # %% -# Compute the multivariate connectivity for the new simulated data. -con = mne_connectivity.spectral_connectivity_epochs( - con_data, indices=indices, method="cacoh" + +# Generate connectivity indices +seeds = np.concatenate( + (np.arange(n_seeds), np.arange(n_channels, n_seeds + n_channels)) +) +targets = np.concatenate( + (np.arange(n_seeds, n_channels), np.arange(n_channels + n_seeds, n_channels * 2)) +) + +multivar_indices = ([seeds], [targets]) + +# Compute CaCoh +cacoh = spectral_connectivity_epochs( + data, method="cacoh", indices=multivar_indices, sfreq=sfreq, fmin=3, fmax=35 ) +############################################################################### +# As you can see below, using CaCoh we have summarised the most relevant +# connectivity information from our 10 seed channels and 6 target channels as a +# single spectrum of connectivity values. This lower-dimensional representation +# of signal interactions is much more interpretable when analysing connectivity +# in complex systems such as the brain. + # %% -# Plot the absolute coherence value for each frequency bin. -plot_absolute_coherency(con, "Non-zero-lag interaction") -# We can see the coherence across frequencies with a notable peak also in the beta -# band, but the coherence values are overall a bit lower than in the zero-lag scenario. -# This illustrates the temporal delay introduced between seed and target signals, -# simulating a more realistic scenario where neuronal communications involve -# transmission delays (such as synaptic or axonal delays). 
+print(f"Results shape: {cacoh.get_data().shape} (connections x frequencies)") +fig, axis = plt.subplots(1, 1) +axis.plot(cacoh.freqs, np.abs(cacoh.get_data()[0]), linewidth=2) +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Connectivity (A.U.)") +fig.suptitle("CaCoh") ############################################################################### -# Theoretical description of canonical Coherence (CaCoh) -# -# In methematical terms, the Canonical Coherence (CaCoh) method aims to maximize the -# coherence between two signal spaces, :math:`A` and :math:`B`, each -# of dimension :math:`\(N_A\)` and :math:`\(N_B\)` respectively. In a practical -# scenario, :math:`A` might represent signals from EMG sensors, and :math:`B` from EEG -# sensors. The primary goal of CaCoh is to find real-valued linear combinations of -# signals from these spaces that maximize coherence at a specific frequency. -# -# This maximization is formulated as (Eq. of 8 in :footcite: `VidaurreEtAl2019`): -# -# :math:`\[ CaCoh = \lambda(\Phi)=\frac{\mathbf{a}^T \mathbf{D}(\Phi) \mathbf{b}}{\sqrt -# {\mathbf{a}^T \mathbf{a} \cdot \mathbf{b}^T \mathbf{b}}} \]` -# -# where :math:`\(\mathbf{D}(\Phi, \a, \b) = \mathbf{C}_{AA}^{-1/2} \mathbf{C}_{AB, \Phi} -# ^R \mathbf{C}_{BB}^{-1/2}\)`. Here, :math:`\(\mathbf{C}_{AB, \Phi}^R\)` denotes the -# real part of the cross-spectrum, while :math:`\(\mathbf{C}_{AA}\)` and :math:`\ -# :math:`\(\beta\)`, respectively. -# -# The method inherently assumes instantaneous mixing of the signals, which justifies -# focusing on the real parts of the cross-spectral matrices. The complex Hermitian -# nature (where a square matrix is equal to its own conjugate transpose) of these -# matrices means that their imaginary components do not contribute to the maximization -# process and are thus typically set to zero. -# -# The analytical resolution of CaCoh leads to an eigenvalue problem (Eq. 
12 of -# VidaurreEtAl2019): -# -# :math:`\[ \mathbf{D}(\Phi)^T \mathbf{D}(\Phi) \mathbf{b} = \lambda \mathbf{b} \]` -# :math:`\[ \mathbf{D}(\Phi) \mathbf{D}(\Phi)^T \mathbf{a} = \lambda \mathbf{a} \]` +# Note that we plot the absolute values of the results (coherence) rather than +# the complex values (coherency). The absolute value of connectivity will +# generally be of most interest, however information such as the phase of +# interaction can only be extracted from the complex-valued results, e.g. with +# the :func:`numpy.angle` function. + +# %% + +fig, axis = plt.subplots(1, 1) +axis.plot(cacoh.freqs, np.angle(cacoh.get_data()[0]), linewidth=2) +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Phase of connectivity (radians)") +fig.suptitle("CaCoh") + +############################################################################### +# CaCoh versus coherence +# ---------------------- # -# where :math:`\(\mathbf{a}\)` and :math:`\(\mathbf{b}\)` are the eigenvectors derived -# from the respective spaces :math:`\(\alpha\)` and :math:`\(\beta\)`. -# :math:`\(\lambda\)`, the maximal eigenvalue, represents the maximal CaCoh. The -# numerical estimation of the phase of coherence, where its absolute value is maximal, -# is achieved through a nonlinear search, emphasizing the method's robustness in -# identifying the most coherent signal combinations across different modalities. +# To further demonstrate the signal-to-noise ratio benefits of CaCoh, below we +# compute connectivity between each seed and target using bivariate coherence. +# With our 10 seeds and 6 targets, this gives us a total of 60 unique +# connections which is very difficult to interpret without aggregating some +# information. A common approach is to simply average across these connections, +# which we do below. 
+ +# %% + +# Define bivariate connectivity indices +bivar_indices = seed_target_indices(seeds=seeds, targets=targets) + +# Compute bivariate coherence +coh = spectral_connectivity_epochs( + data, method="coh", indices=bivar_indices, sfreq=sfreq, fmin=3, fmax=35 +) + +############################################################################### +# Plotting the bivariate and multivariate results together, we can see that +# coherence still captures the interactions at 10-12 Hz and 23-25 Hz, however +# the scale of the connectivity is much smaller. This reflects the fact that +# CaCoh is able to capture the relevant components of interactions between +# multiple signals, regardless of whether they are present in all channels. # -# To provide insights into the locations of sources influencing connectivity, spatial -# patterns can be obtained through spatial filters. To identify the topographies -# corresponding to the spatial filters :math:`\alpha` and :math:`\beta`, the filters -# are multiplied by their respective real part of the cross-spectral matrix, as follows -# (Eq. 14 of :footcite: `VidaurreEtAl2019`): - -# For :math:`\alpha`, calculate: :math:`t_{\boldsymbol{\alpha}} = \mathbf{C}_{A A}^R -# \boldsymbol{\alpha}` -# For :math:`\beta`, calculate: :math:`t_{\boldsymbol{\beta}} = \mathbf{C}_{B B}^R -# \boldsymbol{\beta}` - -# These topographies represent the patterns of the sources with maximum coherence. The -# time courses of CaCoh components directly indicate the activity of neuronal sources. -# The spatial patterns, stored under the connectivity class's 'attrs['patterns']', -# assign a frequency-specific value to each seed and target channel. For simulated -# data, our focus is on coherence analysis without visualizing spatial patterns. -# An example for the visualization for the spatial patterns can be similarly -# accomplished using a the [``MIC``](https://mne.tools/mne-connectivity/stable/auto_examples/mic_mim.html) method (:footcite:`EwaldEtAl2012`). 
+# The ability of multivariate connectivity methods to capture the underlying +# components of connectivity is extremely useful when dealing with data from +# a large number of channels, with inter-channel interactions at distinct +# frequencies, a problem explored in more detail in the :doc:`mic_mim` example. + +# %% + +print(f"Results shape: {coh.get_data().shape} (connections x frequencies)") + +cacoh_min = np.min(np.abs(cacoh.get_data()[0])) +coh_min = np.min(np.mean(coh.get_data(), axis=0)) + +fig, axis = plt.subplots(1, 1) +axis.plot( + coh.freqs, np.abs(cacoh.get_data()[0]) - cacoh_min, linewidth=2, label="CaCoh" +) +axis.plot( + coh.freqs, np.mean(coh.get_data(), axis=0) - coh_min, linewidth=2, label="Coh" +) +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Baseline-corrected connectivity (A.U.)") +axis.legend() +fig.suptitle("CaCoh vs. coherence") ############################################################################### -# Overfitting -# ----------- -# The concern regarding overfitting arises when the spatial filters :math:`\(\alpha\)` -# and :math:`\(\beta\)`, designed to maximize Canonical Coherence (CaCoh), overly adapt -# to the specific dataset, compromising their applicability to new data. This is -# particularly relevant in high-dimensional datasets. To mitigate this, dimensionality -# reduction via Singular Value Decomposition (SVD) is applied to the real part of the -# cross-spectra in the spaces \(A\) and \(B\) before computing the spatial filters -# (Eqs. 14 & 15 of [1]). This process involves selecting singular vectors that preserve -# most of the data's information, ensuring that the derived filters are both effective -# and generalizable. +# Extracting spatial information from CaCoh +# ----------------------------------------- # -# The dimensionality of data can be controlled using the ``rank`` parameter, which by -# default assumes data is of full rank and does not reduce dimensionality. 
To -# accurately reflect the data's structure and avoid bias, it's important to choose a -# rank based on the expected number of significant components. This selection helps -# standardize connectivity estimates across different recordings, even when the number -# of channels varies. Note that this does not refer to the number of seeds and targets -# within a connection being identical, rather to the number of seeds and targets across -# connections. +# Whilst a lower-dimensional representation of connectivity information is +# useful, we lose information about which channels are involved in the +# connectivity. Thankfully, this information can be recovered by constructing +# spatial patterns of connectivity from the spatial filters +# :footcite:`HaufeEtAl2014`. # -# In the following example, we will create two datasets with a larger number of seeds -# and targets. In the first dataet we apply the dimensionality reduction approach to -# only the first component in our rank subspace. We aim to compare the effects on -# connectivity patterns with the second dataset. +# The spatial patterns are stored under ``attrs['patterns']`` of the +# connectivity class, with one value per frequency for each channel in the +# seeds and targets. The patterns can be positive- and negative-valued. Sign +# differences of the patterns can be used to visualise the orientation of +# underlying dipole sources, whereas their absolute value reflects the strength +# of a channel's contribution to the connectivity component. The spatial +# patterns are **not** bound between :math:`[-1, 1]`. # -# The result indicate that essential connectivity patterns are preserved even after -# dimensionality reduction, implying that much of the connectivity information in the -# additional components may be redundant. This approach underscores the efficiency of -# focusing analysis on the most significant data dimensions. 
+# Averaging across the patterns in the 10-12 Hz and 23-25 Hz ranges, we can see +# how it is possible to identify which channels are contributing to +# connectivity at different frequencies. # %% -n_seeds = 15 -n_targets = 10 -indices = ([np.arange(n_seeds)], [n_seeds + np.arange(n_targets)]) -delay = 10 -con_data = simulate_connectivity( - n_seeds=n_seeds, - n_targets=n_targets, - freq_band=freq_bands["beta"], - n_epochs=n_epochs, - n_times=n_times, - sfreq=sfreq, - snr=snr, - connection_delay=delay, - rng_seed=42, -) +freqs = cacoh.freqs +fbands = ((10, 12), ((23, 25))) -con_data_red = simulate_connectivity( - n_seeds=n_seeds, - n_targets=n_targets, - freq_band=freq_bands["beta"], - n_epochs=n_epochs, - n_times=n_times, - sfreq=sfreq, - snr=snr, - connection_delay=delay, - rng_seed=42, -) +fig, axes = plt.subplots(1, 2) -# %% -# Compute the multivariate connectivity using the CaCoh method. -con = mne_connectivity.spectral_connectivity_epochs( - con_data, indices=indices, method="cacoh" -) +# patterns have shape [seeds/targets x cons x channels x freqs (x times)] +patterns = np.abs(np.array(cacoh.attrs["patterns"])) +seed_pattern = patterns[0, :, : len(seeds)] +target_pattern = patterns[1, :, : len(targets)] -con_red = mne_connectivity.spectral_connectivity_epochs( - con_data, indices=indices, method="cacoh", rank=([1], [1]) -) +vmin = np.nanmin(patterns) +vmax = np.nanmax(patterns) + +for axis, fband in zip(axes, fbands): + # average across frequencies + seed_pattern_fband = np.mean( + seed_pattern[0, :, freqs.index(fband[0]) : freqs.index(fband[1]) + 1], axis=1 + ) + target_pattern_fband = np.mean( + target_pattern[0, :, freqs.index(fband[0]) : freqs.index(fband[1]) + 1], axis=1 + ) -# subtract mean of scores for comparison -con_meansub = con.get_data()[0] - con.get_data()[0].mean() -con_red_meansub = con_red.get_data()[0] - con_red.get_data()[0].mean() + # combine into a single array + pattern_fband = np.concatenate((seed_pattern_fband, target_pattern_fband), 
axis=0) -# no. channels equal with and without projecting to rank subspace for patterns -assert ( - np.array(con_red.attrs["patterns"])[0, 0].shape[0] - == np.array(con_red.attrs["patterns"])[0, 0].shape[0] + # plot the pattern + mesh = axis.pcolormesh( + np.flip(np.expand_dims(pattern_fband, 1)), vmin=vmin, vmax=vmax + ) + axis.set_yticks([1.5, 4.5, 8.5, 13.5]) + axis.set_xticks([0.5]) + axis.set_xticklabels([f"{fband[0]}-{fband[1]}"]) + +# Label axes +fig.suptitle("Spatial patterns") +axes[0].set_yticklabels( + [ + "Targets\n(23-25 Hz)", + "Targets\n(10-12 Hz)", + "Seeds\n(23-25 Hz)", + "Seeds\n(10-12 Hz)", + ], + rotation=45, + va="center", ) -assert ( - np.array(con.attrs["patterns"])[1, 0].shape[0] - == np.array(con.attrs["patterns"])[1, 0].shape[0] +axes[0].set_ylabel("Channels") +axes[1].get_yaxis().set_visible(False) +fig.text(0.47, 0.02, "Frequency band (Hz)", ha="center") + +# Set colourbar +fig.subplots_adjust(right=0.8) +cbar_axis = fig.add_axes([0.85, 0.15, 0.02, 0.7]) +fig.colorbar(mesh, cax=cbar_axis) +cbar_axis.set_ylabel("Contribution to connectivity (A.U.)") +cbar_axis.set_yticks([vmin, vmax]) +cbar_axis.set_yticklabels(["Low", "High"]) + +plt.show() + +############################################################################### +# For an example on interpreting spatial filters with real data, see the +# :doc:`mic_mim` example. + +############################################################################### +# Handling high-dimensional data +# ------------------------------ +# +# An important issue to consider when using these multivariate methods is +# overfitting, which risks biasing connectivity estimates to maximise noise in +# the data. This risk can be reduced by performing a preliminary dimensionality +# reduction prior to estimating the connectivity with a singular value +# decomposition (Eq. 15 of :footcite:`VidaurreEtAl2019`). 
The degree of this +# dimensionality reduction can be specified using the ``rank`` argument, which +# by default will not perform any dimensionality reduction (assuming your data +# is full rank; see below if not). Choosing an expected rank of the data +# requires *a priori* knowledge about the number of components you expect to +# observe in the data. +# +# When comparing CaCoh scores across recordings, **it is highly recommended +# to estimate connectivity from the same number of channels (or equally from +# the same degree of rank subspace projection)** to avoid biases in +# connectivity estimates. Bias can be avoided by specifying a consistent rank +# subspace to project to using the ``rank`` argument, standardising your +# connectivity estimates regardless of changes in e.g. the number of channels +# across recordings. Note that this does not refer to the number of seeds and +# targets *within* a connection being identical, rather to the number of seeds +# and targets *across* connections. +# +# Here, we project our seed and target data to only the first 2 components of +# our rank subspace. Results show that the general spectral pattern of +# connectivity is retained in the rank subspace-projected data, suggesting that +# a fair degree of redundant connectivity information is contained in the +# excluded components of the seed and target data. +# +# We also assert that the spatial patterns of CaCoh are returned in the original +# sensor space despite this rank subspace projection, being reconstructed using +# the products of the singular value decomposition (Eqs. 46 & 47 of +# :footcite:`EwaldEtAl2012`).
+ +# %% + +# Compute CaCoh following rank subspace projection +cacoh_red = spectral_connectivity_epochs( + data, + method="cacoh", + indices=multivar_indices, + sfreq=sfreq, + fmin=3, + fmax=35, + rank=([2], [2]), ) -_, axis = plt.subplots() -axis.plot(con.freqs, con_meansub, linewidth=2, label="Standard cacoh") +# compare standard and rank subspace-projected CaCoh +fig, axis = plt.subplots(1, 1) +axis.plot(cacoh.freqs, np.abs(cacoh.get_data()[0]), linewidth=2, label="standard CaCoh") axis.plot( - con_red.freqs, - con_red_meansub, + cacoh_red.freqs, + np.abs(cacoh_red.get_data()[0]), linewidth=2, - label="Rank subspace (1) cacoh", + label="rank subspace (2) CaCoh", ) axis.set_xlabel("Frequency (Hz)") -axis.set_ylabel("Absolute connectivity (A.U.)") -plt.title("CaCoh") -plt.legend(loc="upper right") -plt.show() +axis.set_ylabel("Connectivity (A.U.)") +axis.legend() +fig.suptitle("CaCoh") -# In the case that your data is not full rank and rank is left as None, an automatic -# rank computation is performed and an appropriate degree of dimensionality reduction -# will be enforced. The rank of the data is determined by computing the singular values -# of the data and finding those within a factor of :math: `1e-6` relative to the -# largest singular value. +# no. channels equal with and without projecting to rank subspace for patterns +assert patterns[0, 0].shape[0] == np.array(cacoh_red.attrs["patterns"])[0, 0].shape[0] +assert patterns[1, 0].shape[0] == np.array(cacoh_red.attrs["patterns"])[1, 0].shape[0] -# Whilst unlikely, there may be scenarios in which this threshold may be too lenient. -# In these cases, you should inspect the singular values of your data to identify an -# appropriate degree of dimensionality reduction to perform, which you can then specify -# manually using the ``rank`` argument. The code below shows one possible approach for -# finding an appropriate rank of close-to-singular data with a more conservative -# threshold. 
+############################################################################### +# See :doc:`mic_mim` for an example of applying the rank subspace +# projection to real data with a large number of channels. +# +# In the case that your data is not full rank and ``rank`` is left as ``None``, +# an automatic rank computation is performed and an appropriate degree of +# dimensionality reduction will be enforced. The rank of the data is determined +# by computing the singular values of the data and finding those within a +# factor of :math:`10^{-6}` relative to the largest singular value. +# +# Whilst unlikely, there may be scenarios in which this threshold may be too +# lenient. In these cases, you should inspect the singular values of your data +# to identify an appropriate degree of dimensionality reduction to perform, +# which you can then specify manually using the ``rank`` argument. The code +# below shows one possible approach for finding an appropriate rank of +# close-to-singular data with a more conservative threshold. # %% -# gets the singular values of the data -s = np.linalg.svd(con.get_data(), compute_uv=False) + +# gets the singular values of the data across epochs +s = np.linalg.svd(data, compute_uv=False).min(axis=0) # finds how many singular values are 'close' to the largest singular value rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria -print(rank) ############################################################################### -# Advantages and disadvantages -# -# In EEG data analysis, zero-lag interactions are typically viewed with suspicion -# because they often indicate volume conduction rather than genuine physiological -# interactions. Volume conduction is a phenomenon where electrical currents from active -# neurons spread passively through the brain tissue and skull to reach the scalp, where -# they are recorded.
This can make spatially distinct but electrically active areas -# appear to be synchronously active, creating artificial coherence at zero lag. -# However, it is possible that some zero-lag interactions are real, especially if the -# neural sources are physically close to each other or if there is a common driver -# influencing multiple regions simultaneously. +# Limitations +# ----------- # -# CaCoh, by design, does not specifically distinguish between zero-lag interactions -# that are physiological and those that are artifacts of volume conduction. Its main -# purpose is to identify patterns of maximal coherence across multiple channels or -# conditions. However, because it does not exclude zero-lag interactions, it might not -# inherently differentiate between true connectivity and volume conduction effects. +# Multivariate methods offer many benefits in the form of dimensionality +# reduction and signal-to-noise ratio improvements, however no method is +# perfect. When we simulated the data, we mentioned how we considered the seeds +# and targets to be signals of different modalities. This is an important +# factor in whether CaCoh should be used over methods based solely on the +# imaginary part of coherency such as MIC and MIM. # -# Nevertheless, in the context of LFP signals, which typically represent local field -# potentials recorded from electrodes implanted in the brain, the concern for volume -# conduction is less pronounced compared to EEG because LFP signals are less influenced -# by the spread of electrical activity through the scalp and skull. In this domain, the -# CaCoh method still operates and can potentially capture true zero-lag interactions -# that are physiological in nature. 
The method could distinguish between true zero-lag -# interactions and those resulting from volume conduction if the spatial resolution is -# high enough to separate the sources of the signals, which is often the case with LFPs -# due to their proximity to the neural sources. +# In short, if you want to examine connectivity between signals from the same +# modality, you should consider not using CaCoh. Instead, methods based on the +# imaginary part of coherency such as MIC and MIM should be used to avoid +# spurious connectivity estimates stemming from e.g. volume conduction +# artefacts. # -# On the other hand, the presence of a non-zero lag is often indicative of genuine -# physiological interactions, as it suggests a time course for signal transmission -# across neural pathways. This is especially pertinent in EEG/MEG and LFP analyses. -# CaCoh capture those interactions and help to understand the dynamics of these -# time-lagged connections. +# On the other hand, if you want to examine connectivity between signals from +# different modalities, CaCoh is a more appropriate method than MIC/MIM. This +# is because voilume conduction artefacts are not a concern, and CaCoh does not +# risk biasing connectivity estimates towards interactions with particular +# phase lags like MIC/MIM. # +# These scenarios are described in more detail in the :doc:`cacoh_vs_mic` +# example. ############################################################################### # References From 0c90cfa11cf1d7933ef26fb46035ca6dbfe444a0 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 5 Feb 2024 16:36:06 +0100 Subject: [PATCH 36/59] Update CaCoh example --- examples/cacoh.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/cacoh.py b/examples/cacoh.py index 932d8a99..3f353904 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -60,9 +60,9 @@ # - 5 seeds and 3 targets interacting in the 23-25 Hz frequency range. 
# # We can consider the seeds and targets to be signals of different modalities, -# e.g. cortical EEG signals and subcortical local field potential signals, -# cortical EEG signals and muscular EMG signals, etc.... We use the function -# below to simulate these signals. +# e.g. cortical EEG signals and subcortical LFP signals, cortical EEG signals +# and muscular EMG signals, etc.... We use the function below to simulate these +# signals. # %% @@ -531,7 +531,7 @@ def simulate_connectivity( # # On the other hand, if you want to examine connectivity between signals from # different modalities, CaCoh is a more appropriate method than MIC/MIM. This -# is because voilume conduction artefacts are not a concern, and CaCoh does not +# is because volume conduction artefacts are not a concern, and CaCoh does not # risk biasing connectivity estimates towards interactions with particular # phase lags like MIC/MIM. # From e0d6cd215d211fb765d695e75a21ffd0a213c708 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 6 Feb 2024 19:14:32 +0100 Subject: [PATCH 37/59] Update multivariate examples --- doc/references.bib | 10 + examples/cacoh.py | 121 +++------- examples/cacoh_vs_mic.py | 485 +++++++++++++++++++++++++++++---------- examples/mic_mim.py | 40 ++-- 4 files changed, 431 insertions(+), 225 deletions(-) diff --git a/doc/references.bib b/doc/references.bib index 63d270b2..473e605a 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -278,6 +278,16 @@ @article{VinckEtAl2015 year = {2015} } +@article{ViriyopaseEtAl2012, + author={Viriyopase, Atthaphon and Bojak, Ingo and Zeitler, Magteld and Gielen, Stan}, + doi={10.3389/fncom.2012.00049}, + journal={Frontiers in Computational Neuroscience}, + pages={49}, + title={When long-range zero-lag synchronization is feasible in cortical networks}, + volume={6}, + year={2012} +} + @article{Whittle1963, author = {Whittle, Peter}, doi = {10.1093/biomet/50.1-2.129}, diff --git a/examples/cacoh.py b/examples/cacoh.py index 
3f353904..c1394850 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -44,8 +44,9 @@ # Canonical coherency (CaCoh) is a multivariate form of coherency that uses # spatial filters to extract the relevant components of connectivity in a # frequency-resolved manner :footcite:`VidaurreEtAl2019`. It is similar to -# multivariate methods based on the imaginary part of coherency (MIC & MIM; see -# :doc:`mic_mim`), which are also supported by MNE-Connectivity. +# multivariate methods based on the imaginary part of coherency (MIC & MIM +# :footcite:`EwaldEtAl2012`; see :doc:`mic_mim`), which are also supported by +# MNE-Connectivity. ############################################################################### @@ -67,52 +68,18 @@ # %% -def simulate_connectivity( - n_seeds: int, - n_targets: int, - freq_band: tuple[int, int], - n_epochs: int, - n_times: int, - sfreq: int, - snr: float, - connection_delay: int, - rng_seed: int | None = None, -) -> np.ndarray: +def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarray: """Simulates signals interacting in a given frequency band. Parameters ---------- - n_seeds : int - Number of seed channels to simulate. - - n_targets : int - Number of target channels to simulate. - freq_band : tuple of int, int Frequency band where the connectivity should be simulated, where the first entry corresponds to the lower frequency, and the second entry to the higher frequency. - n_epochs : int - Number of epochs in the simulated data. - - n_times : int - Number of timepoints each epoch of the simulated data. - - sfreq : int - Sampling frequency of the simulated data, in Hz. - - snr : float - Signal-to-noise ratio of the simulated data. - - connection_delay : - Number of timepoints for the delay of connectivity between the seeds - and targets. If > 0, the target data is a delayed form of the seed data - by this many timepoints. - - rng_seed : int | None (default None) - Seed to use for the random number generator. 
If `None`, no seed is - specified. + rng_seed : int + Seed to use for the random number generator. Returns ------- @@ -120,11 +87,19 @@ def simulate_connectivity( The simulated data stored in an array. The channels are arranged according to seeds, then targets. """ - if rng_seed is not None: - np.random.seed(rng_seed) + # Define fixed simulation parameters + n_seeds = 5 + n_targets = 3 + n_epochs = 10 + n_times = 200 # samples + sfreq = 100 # Hz + snr = 0.7 + trans_bandwidth = 1 # Hz + connection_delay = 1 # sample + + np.random.seed(rng_seed) n_channels = n_seeds + n_targets - trans_bandwidth = 1 # Hz # simulate signal source at desired frequency band signal = np.random.randn(1, n_epochs * n_times + connection_delay) @@ -163,39 +138,14 @@ def simulate_connectivity( # %% -# Define simulation parameters -n_seeds = 5 -n_targets = 3 -n_channels = n_seeds + n_targets -n_epochs = 10 -n_times = 200 # samples -sfreq = 100 # Hz -snr = 0.7 -connection_delay = 10 # samples -rng_seed = 44 - # Generate simulated data data_10_12 = simulate_connectivity( - n_seeds=n_seeds, - n_targets=n_targets, freq_band=(10, 12), # 10-12 Hz interaction - n_epochs=n_epochs, - n_times=n_times, - sfreq=sfreq, - snr=snr, - connection_delay=connection_delay, rng_seed=42, ) data_23_25 = simulate_connectivity( - n_seeds=n_seeds, - n_targets=n_targets, freq_band=(23, 25), # 23-25 Hz interaction - n_epochs=n_epochs, - n_times=n_times, - sfreq=sfreq, - snr=snr, - connection_delay=connection_delay, rng_seed=44, ) @@ -236,18 +186,13 @@ def simulate_connectivity( # %% # Generate connectivity indices -seeds = np.concatenate( - (np.arange(n_seeds), np.arange(n_channels, n_seeds + n_channels)) -) -targets = np.concatenate( - (np.arange(n_seeds, n_channels), np.arange(n_channels + n_seeds, n_channels * 2)) -) - +seeds = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12] +targets = [5, 6, 7, 13, 14, 15] multivar_indices = ([seeds], [targets]) # Compute CaCoh cacoh = spectral_connectivity_epochs( - data, method="cacoh", 
indices=multivar_indices, sfreq=sfreq, fmin=3, fmax=35 + data, method="cacoh", indices=multivar_indices, sfreq=100, fmin=3, fmax=35 ) ############################################################################### @@ -261,6 +206,7 @@ def simulate_connectivity( print(f"Results shape: {cacoh.get_data().shape} (connections x frequencies)") +# Plot CaCoh fig, axis = plt.subplots(1, 1) axis.plot(cacoh.freqs, np.abs(cacoh.get_data()[0]), linewidth=2) axis.set_xlabel("Frequency (Hz)") @@ -276,6 +222,7 @@ def simulate_connectivity( # %% +# Plot phase of connectivity fig, axis = plt.subplots(1, 1) axis.plot(cacoh.freqs, np.angle(cacoh.get_data()[0]), linewidth=2) axis.set_xlabel("Frequency (Hz)") @@ -296,11 +243,11 @@ def simulate_connectivity( # %% # Define bivariate connectivity indices -bivar_indices = seed_target_indices(seeds=seeds, targets=targets) +bivar_indices = seed_target_indices(seeds, targets) # Compute bivariate coherence coh = spectral_connectivity_epochs( - data, method="coh", indices=bivar_indices, sfreq=sfreq, fmin=3, fmax=35 + data, method="coh", indices=bivar_indices, sfreq=100, fmin=3, fmax=35 ) ############################################################################### @@ -309,11 +256,6 @@ def simulate_connectivity( # the scale of the connectivity is much smaller. This reflects the fact that # CaCoh is able to capture the relevant components of interactions between # multiple signals, regardless of whether they are present in all channels. -# -# The ability of multivariate connectivity methods to capture the underlying -# components of connectivity is extremely useful when dealing with data from -# a large number of channels, with inter-channel interactions at distinct -# frequencies, a problem explored in more detail in the :doc:`mic_mim` example. 
# %% @@ -322,6 +264,7 @@ def simulate_connectivity( cacoh_min = np.min(np.abs(cacoh.get_data()[0])) coh_min = np.min(np.mean(coh.get_data(), axis=0)) +# Plot CaCoh & Coh fig, axis = plt.subplots(1, 1) axis.plot( coh.freqs, np.abs(cacoh.get_data()[0]) - cacoh_min, linewidth=2, label="CaCoh" @@ -334,6 +277,12 @@ def simulate_connectivity( axis.legend() fig.suptitle("CaCoh vs. coherence") +############################################################################### +# The ability of multivariate connectivity methods to capture the underlying +# components of connectivity is extremely useful when dealing with data from +# a large number of channels, with inter-channel interactions at distinct +# frequencies, a problem explored in more detail in the :doc:`mic_mim` example. + ############################################################################### # Extracting spatial information from CaCoh # ----------------------------------------- @@ -464,7 +413,7 @@ def simulate_connectivity( data, method="cacoh", indices=multivar_indices, - sfreq=sfreq, + sfreq=100, fmin=3, fmax=35, rank=([2], [2]), @@ -481,7 +430,7 @@ def simulate_connectivity( ) axis.set_xlabel("Frequency (Hz)") axis.set_ylabel("Connectivity (A.U.)") -axis.legend() +axis.legend(loc="lower right") fig.suptitle("CaCoh") # no. channels equal with and without projecting to rank subspace for patterns @@ -531,8 +480,8 @@ def simulate_connectivity( # # On the other hand, if you want to examine connectivity between signals from # different modalities, CaCoh is a more appropriate method than MIC/MIM. This -# is because volume conduction artefacts are not a concern, and CaCoh does not -# risk biasing connectivity estimates towards interactions with particular +# is because volume conduction artefacts are of less concern, and CaCoh does +# not risk biasing connectivity estimates towards interactions with particular # phase lags like MIC/MIM. 
# # These scenarios are described in more detail in the :doc:`cacoh_vs_mic` diff --git a/examples/cacoh_vs_mic.py b/examples/cacoh_vs_mic.py index 1eb97c90..8bf901bc 100644 --- a/examples/cacoh_vs_mic.py +++ b/examples/cacoh_vs_mic.py @@ -3,10 +3,9 @@ Comparison of coherency-based methods ===================================== -This example demonstrates how canonical coherency (CaCoh) -:footcite:`VidaurreEtAl2019` - a multivariate method based on coherency - can -be used to compute connectivity between whole sets of sensors, alongside -spatial patterns of the connectivity. +This example demonstrates the distinct forms of information captured by +coherency-based connectivity methods, and highlights the scenarios in which +these different methods should be applied. """ # Authors: Thomas S. Binns @@ -20,55 +19,89 @@ import mne from mne_connectivity import seed_target_indices, spectral_connectivity_epochs +############################################################################### +# An introduction to coherency-based connectivity methods +# ------------------------------------------------------- +# +# MNE-Connectivity supports several methods based on coherency. These are: +# +# - coherency (Cohy) +# - coherence (Coh; absolute coherency) +# - imaginary part of coherency (ImCoh) +# - canonical coherency (CaCoh) +# - maximised imaginary part of coherency (MIC) +# - multivariate interaction measure (MIM; based on ImCoh) +# +# | +# +# All of these methods centre on Cohy, a complex-valued estimate of +# connectivity between signals in the frequency domain. +# +# The common approach for handling these complex-valued coherency scores is to +# either take their absolute values (Coh) or their imaginary values (ImCoh +# :footcite:`NolteEtAl2004`).
+# +# In addition to these traditional bivariate connectivity measures, advanced +# multivariate measures have also been developed based on Cohy (CaCoh +# :footcite:`VidaurreEtAl2019`; can take the absolute value for a multivariate +# form of Coh) or ImCoh (MIC & MIM :footcite:`EwaldEtAl2012`). +# +# Despite their similarities, there are distinct scenarios in which these +# different methods are most appropriate, as we will show in this example. + +############################################################################### +# Zero and non-zero time-lag interactions +# --------------------------------------- +# +# The key difference between Cohy/Coh and ImCoh is how information about zero +# time-lag interactions is captured. +# +# We generally assume that communication within the brain involves some delay +# in the flow of information (i.e. a non-zero time-lag). This reflects the time +# taken for: the propagation of action potentials along axons; the release of +# neurotransmitters from presynaptic terminals and binding to receptors on +# postsynaptic terminals; etc... +# +# In contrast, interactions with no delay (i.e. a zero time-lag) are often +# considered to reflect non-physiological activity, such as volume conduction +# - the propagation of electrical activity through the brain's conductive +# tissue from a single source to multiple electrodes simultaneously +# :footcite:`NolteEtAl2004`. Such interactions therefore do not reflect +# genuine, physiological communication between brain regions. Naturally, +# having a method that can discard spurious zero time-lag connectivity +# estimates is very desirable. +# +# **N.B.** Not all zero time-lag interactions are necessarily non-physiological +# :footcite:`ViriyopaseEtAl2012`. +# +# To demonstrate the differences in how Cohy/Coh and ImCoh handle zero time-lag +# interactions, we simulate two sets of data with: +# +# 1. A non-zero time-lag interaction at 10-12 Hz. +# 2. A zero time-lag interaction at 23-25 Hz. 
+ # %% def simulate_connectivity( - n_seeds: int, - n_targets: int, - freq_band: tuple[int, int], - n_epochs: int, - n_times: int, - sfreq: int, - snr: float, - connection_delay: int, - rng_seed: int | None = None, + freq_band: tuple[int, int], connection_delay: int, rng_seed: int ) -> np.ndarray: """Simulates signals interacting in a given frequency band. Parameters ---------- - n_seeds : int - Number of seed channels to simulate. - - n_targets : int - Number of target channels to simulate. - freq_band : tuple of int, int Frequency band where the connectivity should be simulated, where the first entry corresponds to the lower frequency, and the second entry to the higher frequency. - n_epochs : int - Number of epochs in the simulated data. - - n_times : int - Number of timepoints each epoch of the simulated data. - - sfreq : int - Sampling frequency of the simulated data, in Hz. - - snr : float - Signal-to-noise ratio of the simulated data. - connection_delay : Number of timepoints for the delay of connectivity between the seeds and targets. If > 0, the target data is a delayed form of the seed data by this many timepoints. - rng_seed : int | None (default None) - Seed to use for the random number generator. If `None`, no seed is - specified. + rng_seed : int + Seed to use for the random number generator. Returns ------- @@ -76,11 +109,18 @@ def simulate_connectivity( The simulated data stored in an array. The channels are arranged according to seeds, then targets. 
""" - if rng_seed is not None: - np.random.seed(rng_seed) + # Define fixed simulation parameters + n_seeds = 3 + n_targets = 3 + n_epochs = 10 + n_times = 200 # samples + sfreq = 100 # Hz + snr = 0.7 + trans_bandwidth = 1 # Hz + + np.random.seed(rng_seed) n_channels = n_seeds + n_targets - trans_bandwidth = 1 # Hz # simulate signal source at desired frequency band signal = np.random.randn(1, n_epochs * n_times + connection_delay) @@ -117,69 +157,51 @@ def simulate_connectivity( # %% -# Define simulation parameters -n_seeds = 3 -n_targets = 3 -n_channels = n_seeds + n_targets -n_epochs = 10 -n_times = 200 # samples -sfreq = 100 # Hz -snr = 0.7 -rng_seed = 44 - # Generate simulated data data_delay = simulate_connectivity( - n_seeds=n_seeds, - n_targets=n_targets, freq_band=(10, 12), # 10-12 Hz interaction - n_epochs=n_epochs, - n_times=n_times, - sfreq=sfreq, - snr=snr, - connection_delay=2, # samples + connection_delay=2, # samples; non-zero time-lag rng_seed=42, ) data_no_delay = simulate_connectivity( - n_seeds=n_seeds, - n_targets=n_targets, freq_band=(23, 25), # 23-25 Hz interaction - n_epochs=n_epochs, - n_times=n_times, - sfreq=sfreq, - snr=snr, - connection_delay=0, # samples + connection_delay=0, # samples; zero time-lag rng_seed=44, ) # Combine data into a single array data = np.concatenate((data_delay, data_no_delay), axis=1) +############################################################################### +# We compute the connectivity of these simulated signals using CaCoh (a +# multivariate form of Cohy/Coh) and MIC (a multivariate form of ImCoh). 
+ # %% -# Generate connectivity indices -seeds = np.concatenate( - (np.arange(n_seeds), np.arange(n_channels, n_seeds + n_channels)) -) -targets = np.concatenate( - (np.arange(n_seeds, n_channels), np.arange(n_channels + n_seeds, n_channels * 2)) -) +n_seeds = 3 +n_targets = 3 +n_channels = n_seeds + n_targets -bivar_indices = (seeds, targets) +# Generate connectivity indices +seeds = [0, 1, 2, 6, 7, 8] +targets = [3, 4, 5, 9, 10, 11] +bivar_indices = seed_target_indices(seeds, targets) multivar_indices = ([seeds], [targets]) # Compute CaCoh & MIC (cacoh, mic) = spectral_connectivity_epochs( - data, - method=["cacoh", "mic"], - indices=multivar_indices, - sfreq=sfreq, - fmin=3, - fmax=35, + data, method=["cacoh", "mic"], indices=multivar_indices, sfreq=100, fmin=3, fmax=35 ) +############################################################################### +# As you can see, both CaCoh and MIC capture the non-zero time-lag interaction +# at 10-12 Hz, however only CaCoh captures the zero time-lag interaction at +# 23-25 Hz. + # %% +# Plot CaCoh & MIC fig, axis = plt.subplots(1, 1) axis.plot(cacoh.freqs, np.abs(cacoh.get_data()[0]), linewidth=2, label="CaCoh") axis.plot( @@ -187,21 +209,126 @@ def simulate_connectivity( ) axis.set_xlabel("Frequency (Hz)") axis.set_ylabel("Connectivity (A.U.)") -axis.annotate("Non-zero\ntime lag\ninteraction", xy=(13, 0.85)) -axis.annotate("Zero\ntime lag\ninteraction", xy=(27, 0.85)) +axis.annotate("Non-zero\ntime-lag\ninteraction", xy=(13.5, 0.85)) +axis.annotate("Zero\ntime-lag\ninteraction", xy=(26.5, 0.85)) axis.legend(loc="upper left") -fig.suptitle("CaCoh vs. MIC\nNon-zero & zero time lags") +fig.suptitle("CaCoh vs. MIC\nNon-zero & zero time-lags") + + +# %% + + +def plot_connectivity_circle(): + """Plot a circle with radius 1, real and imag. 
axes, and angles marked.""" + fig, axis = plt.subplots(1, 1) + t = np.linspace(0, 2 * np.pi, 100) + axis.plot(np.cos(t), np.sin(t), color="k", linewidth=0.1) + axis.plot([-1, 1], [0, 0], color="k", linestyle="--") + axis.plot([0, 0], [-1, 1], color="k", linestyle="--") + axis.axis("off") + + fontdict = {"fontsize": 10} + qpi = np.pi / 4 + axis.text(1, 0, " 0°", ha="left", va="center", fontdict=fontdict) + axis.text(np.pi / 4, np.pi / 4, "45°", ha="center", va="center", fontdict=fontdict) + axis.text(0, 1, "90°", ha="center", va="bottom", fontdict=fontdict) + axis.text(-qpi, qpi, "135°", ha="center", va="center", fontdict=fontdict) + axis.text(-1, 0, "180°", ha="right", va="center", fontdict=fontdict) + axis.text(-qpi, -qpi, "-135°", ha="center", va="center", fontdict=fontdict) + axis.text(0, -1, "-90°", ha="center", va="top", fontdict=fontdict) + axis.text(qpi, -qpi, "-45°", ha="center", va="center", fontdict=fontdict) + + fontdict = {"fontsize": 12} + axis.text(1.15, 0, " Real", ha="left", va="center", fontdict=fontdict) + axis.text(0, 1.15, "Imaginary", ha="center", va="bottom", fontdict=fontdict) + axis.text(0, 0, "0 ", ha="right", va="top", fontdict=fontdict) + axis.text(-1, 0, "-1", ha="left", va="top", fontdict=fontdict) + axis.text(1, 0, "+1", ha="right", va="top", fontdict=fontdict) + axis.text(0, -1, "-1 ", ha="right", va="bottom", fontdict=fontdict) + axis.text(0, 1, "+1 ", ha="right", va="top", fontdict=fontdict) + + axis.set_aspect("equal") + + return fig, axis + + +############################################################################### +# The different interactions (not) captured by CaCoh and MIC can be understood +# by visualising the complex values of the interactions. 
+ +# %% + +# Get complex connectivity values at frequency bands +freqs = cacoh.freqs +cacoh_10_12 = np.mean(cacoh.get_data()[0, freqs.index(10) : freqs.index(12) + 1]) +cacoh_23_25 = np.mean(cacoh.get_data()[0, freqs.index(23) : freqs.index(25) + 1]) + +# Plot complex connectivity values +fig, axis = plot_connectivity_circle() +axis.quiver( + 0, + 0, + np.real(cacoh_10_12), + np.imag(cacoh_10_12), + units="xy", + scale=1, + linewidth=2, + color="C2", + label="10-12 Hz", +) +axis.quiver( + 0, + 0, + np.real(cacoh_23_25), + np.imag(cacoh_23_25), + units="xy", + scale=1, + linewidth=2, + color="C3", + label="23-25 Hz", + zorder=99, +) +axis.legend(loc="upper right", bbox_to_anchor=[1.1, 1.1]) + +############################################################################### +# Above, we plot the complex-valued CaCoh scores for the 10-12 Hz and 23-25 Hz +# interactions as vectors with origin :math:`(0, 0)` bound within a circle of +# radius 1 (reflecting the fact that coherency scores span the set of complex +# values in the range :math:`[-1, 1]`). +# +# The circumference of the circle spans the range :math:`(-\pi, \pi]`. The real +# axis corresponds to vectors with angles of 0° (:math:`0\pi`; positive +# values) or 180° (:math:`\pi`; negative values). The imaginary axis +# corresponds to vectors with angles of 90° (:math:`\frac{1}{2}\pi`; positive +# values) or -90° (:math:`-\frac{1}{2}\pi`; negative values). +# +# Zero time-lag interactions have angles of 0° and 180° (i.e. no phase +# difference), corresponding to a non-zero real component, but a zero-valued +# imaginary component. We see this nicely for the 23-25 Hz interaction, which +# has an angle of ~0°. Taking the absolute CaCoh value shows us the magnitude +# of this interaction to be ~0.9. However, first projecting this information to +# the imaginary axis gives us a magnitude of ~0. +# +# In contrast, non-zero time-lag interactions do not lie on the real axis (i.e. 
+# a phase difference), corresponding to non-zero real and imaginary components. +# We see this nicely for the 10-12 Hz interaction, which has an angle of ~-75°. +# Taking the absolute CaCoh value shows us the magnitude of this interaction to +# be ~0.9, which is also seen when first projecting this information to the +# imaginary axis. +# +# This distinction is why connectivity methods utilising information from both +# real and imaginary components (Cohy, Coh, CaCoh) capture both zero and +# non-zero time-lag interactions, whereas methods using only the imaginary +# component (ImCoh, MIC, MIM) capture only non-zero time-lag interactions. +# +# The ability to capture these different interactions is not a feature specific +# to multivariate connectivity methods, as shown below for Coh and ImCoh. # %% # Compute Coh & ImCoh (coh, imcoh) = spectral_connectivity_epochs( - data, - method=["coh", "imcoh"], - indices=bivar_indices, - sfreq=sfreq, - fmin=3, - fmax=35, + data, method=["coh", "imcoh"], indices=bivar_indices, sfreq=100, fmin=3, fmax=35 ) coh_mean = np.mean(coh.get_data(), axis=0) @@ -210,39 +337,89 @@ def simulate_connectivity( coh_mean_subbed = coh_mean - np.mean(coh_mean) imcoh_mean_subbed = imcoh_mean - np.mean(imcoh_mean) +# Plot Coh & ImCoh fig, axis = plt.subplots(1, 1) axis.plot(coh.freqs, coh_mean_subbed, linewidth=2, label="Coh") axis.plot(imcoh.freqs, imcoh_mean_subbed, linewidth=2, label="ImCoh", linestyle="--") axis.set_xlabel("Frequency (Hz)") axis.set_ylabel("Mean-corrected connectivity (A.U.)") -axis.annotate("Non-zero\ntime lag\ninteraction", xy=(13, 0.25)) -axis.annotate("Zero\ntime lag\ninteraction", xy=(25, 0.25)) +axis.annotate("Non-zero\ntime-lag\ninteraction", xy=(13, 0.25)) +axis.annotate("Zero\ntime-lag\ninteraction", xy=(25, 0.25)) axis.legend(loc="upper left") -fig.suptitle("Coh vs. ImCoh\nNon-zero & zero time lags") +fig.suptitle("Coh vs. 
ImCoh\nNon-zero & zero time-lags") + +############################################################################### +# When different coherency-based methods are most appropriate +# ----------------------------------------------------------- +# +# With this information, we can define situations under which these different +# approaches are most appropriate. +# +# | +# +# **In situations where non-physiological zero time-lag interactions are +# assumed, methods based on only the imaginary part of coherency (ImCoh, MIC, +# MIM) should be used.** Examples of situations include: +# +# - Connectivity between channels of a single modality where volume conduction +# is expected, e.g. connectivity between EEG channels. +# - Connectivity between channels of different modalities where a common +# reference is used, e.g. connectivity between EEG and subcortical LFP using +# the same LFP reference. +# +# | +# +# **In situations where non-physiological zero time-lag interactions are not +# assumed, methods based on real and imaginary parts of coherency (Cohy, Coh, +# CaCoh) should be used.** Examples of situations include: +# +# - Connectivity between channels of a single modality where volume conduction +# is not expected, e.g. connectivity between ECoG channels. +# - Connectivity between channels of different modalities where different +# references are used, e.g. connectivity between EEG and subcortical LFP +# using cortical and subcortical references, respectively. +# +# | +# +# Although methods based on only the imaginary part of coherency should be used +# when non-physiological zero time-lag interactions are present, these methods +# should equally be avoided when such non-physiological interactions are +# absent. There are 2 key reasons: +# +# **1. Discarding physiological zero time-lag interactions** +# +# First, not all zero time-lag interactions are non-physiological +# :footcite:`ViriyopaseEtAl2012`. 
Accordingly, methods based on only the +# imaginary part of coherency may lead to information about genuine +# connectivity being lost. +# +# In situations where non-physiological zero time-lag +# interactions are present, the potential loss of physiological information is +# generally acceptable to avoid spurious connectivity estimates. However, +# unnecessarily discarding this information can be detrimental. +# +# **2. Biasing interactions based on the angle of interaction** +# +# Depending on their angles, two non-zero time-lag interactions may have the +# same magnitude in the complex plane, but different magnitudes when projected +# to the imaginary axis. +# +# This is demonstrated below, where we simulate 2 interactions with non-zero +# time-lags at 10-12 Hz and 23-25 Hz. Computing the connectivity, we see how +# both interactions have a similar magnitude (~0.9), but different angles +# (~-45° for 10-12 Hz; ~-90° for 23-25 Hz). # %% # Generate simulated data data_10_12 = simulate_connectivity( - n_seeds=n_seeds, - n_targets=n_targets, freq_band=(10, 12), # 10-12 Hz interaction - n_epochs=n_epochs, - n_times=n_times, - sfreq=sfreq, - snr=snr, connection_delay=1, # samples rng_seed=42, ) data_23_25 = simulate_connectivity( - n_seeds=n_seeds, - n_targets=n_targets, freq_band=(23, 25), # 23-25 Hz interaction - n_epochs=n_epochs, - n_times=n_times, - sfreq=sfreq, - snr=snr, connection_delay=1, # samples rng_seed=44, ) @@ -252,14 +429,54 @@ def simulate_connectivity( # Compute CaCoh & MIC (cacoh, mic) = spectral_connectivity_epochs( - data, - method=["cacoh", "mic"], - indices=multivar_indices, - sfreq=sfreq, - fmin=3, - fmax=35, + data, method=["cacoh", "mic"], indices=multivar_indices, sfreq=100, fmin=3, fmax=35 ) +# Get complex connectivity values at frequency bands +freqs = cacoh.freqs +cacoh_10_12 = np.mean(cacoh.get_data()[0, freqs.index(10) : freqs.index(12) + 1]) +cacoh_23_25 = np.mean(cacoh.get_data()[0, freqs.index(23) : freqs.index(25) + 1]) + +# Plot
complex connectivity values +fig, axis = plot_connectivity_circle() +axis.quiver( + 0, + 0, + np.real(cacoh_10_12), + np.imag(cacoh_10_12), + units="xy", + scale=1, + linewidth=2, + color="C2", + label="10-12 Hz", +) +axis.quiver( + 0, + 0, + np.real(cacoh_23_25), + np.imag(cacoh_23_25), + units="xy", + scale=1, + linewidth=2, + color="C3", + label="23-25 Hz", + zorder=99, +) +axis.legend(loc="upper right", bbox_to_anchor=[1.1, 1.1]) + +############################################################################### +# Plotting the connectivity values for CaCoh and MIC, we see how the 10-12 Hz +# and 23-25 Hz interactions only have a similar magnitude for CaCoh, whereas +# the MIC scores for the 10-12 Hz interaction are lower than for the 23-25 Hz +# interaction. +# +# This difference reflects the fact that as the angle of interaction deviates +# from :math:`\pm` 90°, less information will be represented in the imaginary +# part of coherency. + +# %% + +# Plot CaCoh & MIC fig, axis = plt.subplots(1, 1) axis.plot(cacoh.freqs, np.abs(cacoh.get_data()[0]), linewidth=2, label="CaCoh") axis.plot( @@ -267,36 +484,70 @@ def simulate_connectivity( ) axis.set_xlabel("Frequency (Hz)") axis.set_ylabel("Connectivity (A.U.)") -axis.annotate("45°\ninteraction", xy=(12.5, 0.9)) -axis.annotate("90°\ninteraction", xy=(26.5, 0.9)) +axis.annotate("$\pm$45°\ninteraction", xy=(12.5, 0.9)) +axis.annotate("$\pm$90°\ninteraction", xy=(26.5, 0.9)) axis.legend(loc="upper left") -fig.suptitle("CaCoh vs. MIC\n45° & 90° interactions") +fig.suptitle("CaCoh vs. MIC\n$\pm$45° & $\pm$90° interactions") + +############################################################################### +# Accordingly, considering only the imaginary part of coherency can bias +# connectivity estimates based on the proximity of the phase angle of +# interactions to :math:`\pm` 90°, with closer angles leading to higher +# estimates of connectivity. 
+# +# Again, in situations where non-physiological zero time-lag interactions are +# present, this phase angle-dependent bias is generally acceptable to avoid +# spurious connectivity estimates. However, such a bias in situations where +# non-physiological zero time-lag interactions are not present is clearly +# problematic. +# +# | +# +# Again, these considerations are not specific to multivariate methods, as +# shown below with Coh and ImCoh. # %% # Compute Coh & ImCoh (coh, imcoh) = spectral_connectivity_epochs( - data, - method=["coh", "imcoh"], - indices=bivar_indices, - sfreq=sfreq, - fmin=3, - fmax=35, + data, method=["coh", "imcoh"], indices=bivar_indices, sfreq=100, fmin=3, fmax=35 ) coh_mean = np.mean(coh.get_data(), axis=0) imcoh_mean = np.mean(np.abs(imcoh.get_data()), axis=0) + coh_mean_subbed = coh_mean - np.mean(coh_mean) imcoh_mean_subbed = imcoh_mean - np.mean(imcoh_mean) +# Plot Coh & ImCoh fig, axis = plt.subplots(1, 1) axis.plot(coh.freqs, coh_mean_subbed, linewidth=2, label="Coh") axis.plot(imcoh.freqs, imcoh_mean_subbed, linewidth=2, label="ImCoh", linestyle="--") axis.set_xlabel("Frequency (Hz)") axis.set_ylabel("Mean-corrected connectivity (A.U.)") -axis.annotate("45°\ninteraction", xy=(12, 0.25)) -axis.annotate("90°\ninteraction", xy=(26.5, 0.25)) +axis.annotate("$\pm$45°\ninteraction", xy=(12, 0.25)) +axis.annotate("$\pm$90°\ninteraction", xy=(26.5, 0.25)) axis.legend(loc="upper left") -fig.suptitle("Coh vs. ImCoh\n45° & 90° interactions") +fig.suptitle("Coh vs. ImCoh\n$\pm$45° & $\pm$90° interactions") + +############################################################################### +# Conclusion +# ---------- +# +# Altogether, there are clear scenarios in which different coherency-based +# methods are appropriate. +# +# Methods based on the imaginary part of coherency alone (ImCoh, MIC, MIM) +# should only be used when non-physiological zero time-lag interactions (e.g. +# volume conduction) are present.
+# +# Methods based on the real and imaginary parts of coherency (Cohy, Coh, CaCoh) +# should only be used when non-physiological zero time-lag interactions are not +# present. + +############################################################################### +# References +# ---------- +# .. footbibliography:: # %% diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 464a46b3..2d9b4bac 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -53,8 +53,8 @@ # multivariate manner :footcite:`EwaldEtAl2012`. This approach leads to the # following methods: the maximised imaginary part of coherency (MIC); and the # multivariate interaction measure (MIM). These methods are similar to the -# multivariate method based on coherency (CaCoh; see :doc:`cacoh`), which is -# also supported by MNE-Connectivity. +# multivariate method based on coherency (CaCoh :footcite:`VidaurreEtAl2019`; +# see :doc:`cacoh`), which is also supported by MNE-Connectivity. # # We start by loading some example MEG data and dividing it into # two-second-long epochs. @@ -376,7 +376,6 @@ fig, axis = plt.subplots(1, 1) axis.plot(mim.freqs, mim_meansub, linewidth=2, label="standard MIM") axis.plot(mim_red.freqs, mim_red_meansub, linewidth=2, label="rank subspace (25) MIM") -axis.plot(mim.freqs, mim_meansub, linewidth=2, label="standard MIM") axis.set_xlabel("Frequency (Hz)") axis.set_ylabel("Mean-corrected connectivity (A.U.)") axis.legend() @@ -415,27 +414,24 @@ # # These multivariate methods offer many benefits in the form of dimensionality # reduction, signal-to-noise ratio improvements, and invariance to -# estimate-biasing source mixing; however, no method is perfect. The immunity -# of the imaginary part of coherency to volume conduction comes from the fact -# that these artefacts have zero phase lag, and hence a zero-valued imaginary -# component. 
By projecting the complex-valued coherency to the imaginary axis, -# signals of a given magnitude with phase lag differences close to 90° or 270° -# see their contributions to the connectivity estimate increased relative to -# comparable signals with phase lag differences close to 0° or 180°. Therefore, -# the imaginary part of coherency is biased towards connectivity involving 90° -# and 270° phase lag difference components. +# estimate-biasing source mixing; however, no method is perfect. Important +# considerations must be taken into account when choosing methods based on the +# imaginary part of coherency such as MIC or MIM versus those based on +# coherency/coherence, such as CaCoh. +# +# In short, if you want to examine connectivity between signals from the same +# modality, you should consider using MIC and MIM to avoid spurious +# connectivity estimates stemming from e.g. volume conduction artefacts. # -# Whilst this is not a limitation specific to the multivariate extension of -# this measure, these multivariate methods can introduce further bias: when -# maximising the imaginary part of coherency, components with phase lag -# differences close to 90° and 270° will likely give higher connectivity -# estimates, and so may be prioritised by the spatial filters. +# On the other hand, if you want to examine connectivity between signals from +# different modalities, CaCoh is a more appropriate method than MIC/MIM. This +# is because volume conduction artefacts are of less concern, and CaCoh does +# not risk biasing connectivity estimates towards interactions with particular +# phase lags like MIC/MIM. # -# Such a limitation should be kept in mind when estimating connectivity using -# these methods.
Possible sanity checks can involve comparing the spectral -# profiles of MIC/MIM to coherence and the imaginary part of coherency -# computed on the same data, as well as comparing to other multivariate -# measures, such as canonical coherence :footcite:`VidaurreEtAl2019`. +# These scenarios are described in more detail in the :doc:`cacoh_vs_mic` +# example. + ############################################################################### # References From d69dc2116f7db96f95f5fe1e240a1c9af72c4b61 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sun, 18 Feb 2024 22:04:47 +0100 Subject: [PATCH 38/59] Fix CITATION.cff --- CITATION.cff | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 41ea1d98..b6dc8467 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -29,6 +29,6 @@ authors: title: "mne-connectivity" -version: 0.7.0 -date-released: 2024-XX-XX +version: 0.6.0 +date-released: 2023-12-06 url: "https://github.com/mne-tools/mne-connectivity" From aa8c3071eb48ad55bc9e5c0c9423362598ae855f Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sun, 18 Feb 2024 22:05:15 +0100 Subject: [PATCH 39/59] Update docstrings --- mne_connectivity/spectral/epochs.py | 4 ++-- mne_connectivity/spectral/time.py | 4 ++-- mne_connectivity/utils/docs.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index e157799f..0cec5b2d 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -664,8 +664,8 @@ def spectral_connectivity_epochs( * %(gc)s * %(gc_tr)s - Multivariate methods (``['cacoh', 'mic', 'mim', 'gc', 'gc_tr]``) cannot - be called with the other methods. + Multivariate methods (``['cacoh', 'mic', 'mim', 'gc', 'gc_tr']``) + cannot be called with the other methods. indices : tuple of array | None Two arrays with indices of connections for which to compute connectivity.
If a bivariate method is called, each array for the seeds diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 233de3c4..abba5b00 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -84,8 +84,8 @@ def spectral_connectivity_time( * %(gc)s * %(gc_tr)s - Multivariate methods (``['cacoh', 'mic', 'mim', 'gc', 'gc_tr]``) cannot - be called with the other methods. + Multivariate methods (``['cacoh', 'mic', 'mim', 'gc', 'gc_tr']``) + cannot be called with the other methods. average : bool Average connectivity scores over epochs. If ``True``, output will be an instance of :class:`SpectralConnectivity`, otherwise diff --git a/mne_connectivity/utils/docs.py b/mne_connectivity/utils/docs.py index 3c4430a3..fe56b1d9 100644 --- a/mne_connectivity/utils/docs.py +++ b/mne_connectivity/utils/docs.py @@ -70,7 +70,7 @@ docdict["coh"] = "'coh' : Coherence" docdict["cohy"] = "'cohy' : Coherency" -docdict["imcoh"] = "'imcoh' : Imaginary part of coherency" +docdict["imcoh"] = "'imcoh' : Imaginary part of Coherency" docdict["cacoh"] = "'cacoh' : Canonical Coherency (CaCoh)" docdict["mic"] = "'mic' : Maximised Imaginary part of Coherency (MIC)" docdict["mim"] = "'mim' : Multivariate Interaction Measure (MIM)" From 8ff932c14392293dc0fed4dc20cb1a1c703ea21e Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sun, 18 Feb 2024 22:15:24 +0100 Subject: [PATCH 40/59] Update examples --- examples/cacoh.py | 7 +- ...vs_mic.py => compare_coherency_methods.py} | 130 ++++++++++++------ examples/mic_mim.py | 12 +- 3 files changed, 98 insertions(+), 51 deletions(-) rename examples/{cacoh_vs_mic.py => compare_coherency_methods.py} (79%) diff --git a/examples/cacoh.py b/examples/cacoh.py index c1394850..b169ff4f 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -45,7 +45,8 @@ # spatial filters to extract the relevant components of connectivity in a # frequency-resolved manner :footcite:`VidaurreEtAl2019`. 
It is similar to # multivariate methods based on the imaginary part of coherency (MIC & MIM -# :footcite:`EwaldEtAl2012`; see :doc:`mic_mim`), which are also supported by +# :footcite:`EwaldEtAl2012`; see :doc:`mic_mim` and +# :doc:`compare_coherency_methods`), which are also supported by # MNE-Connectivity. @@ -484,8 +485,8 @@ def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarr # not risk biasing connectivity estimates towards interactions with particular # phase lags like MIC/MIM. # -# These scenarios are described in more detail in the :doc:`cacoh_vs_mic` -# example. +# These scenarios are described in more detail in the +# :doc:`compare_coherency_methods` example. ############################################################################### # References diff --git a/examples/cacoh_vs_mic.py b/examples/compare_coherency_methods.py similarity index 79% rename from examples/cacoh_vs_mic.py rename to examples/compare_coherency_methods.py index 8bf901bc..7c922341 100644 --- a/examples/cacoh_vs_mic.py +++ b/examples/compare_coherency_methods.py @@ -30,19 +30,22 @@ # - imaginary part of coherency (ImCoh) # - canonical coherency (CaCoh) # - maximised imaginary part of coherency (MIC) -# - multivariate interaction measure (MIM; based on ImCoh) +# - multivariate interaction measure (MIM) # # | # -# All of these methods centre on Cohy, a complex-valued estimate of -# of connectivity between signals in the frequency domain. +# All of these methods centre on Cohy, a complex-valued estimate of the +# correlation between signals in the frequency domain. It is an undirected +# measure of connectivity, being invariant to the direction of information flow +# between signals. # # The common approach for handling these complex-valued coherency scores is to # either take their absolute values (Coh) or their imaginary values (ImCoh # :footcite:`NolteEtAl2004`). 
# -# In addition to these traditional bivariate connectivity measures, advanced -# multivariate measures have also been developed based on Cohy (CaCoh +# In addition to these traditional bivariate connectivity measures (i.e. +# between two signals), advanced multivariate measures (i.e. between groups of +# signals) have also been developed based on Cohy (CaCoh # :footcite:`VidaurreEtAl2019`; can take the absolute value for a multivariate # form of Coh) or ImCoh (MIC & MIM :footcite:`EwaldEtAl2012`). # @@ -179,10 +182,6 @@ def simulate_connectivity( # %% -n_seeds = 3 -n_targets = 3 -n_channels = n_seeds + n_targets - # Generate connectivity indices seeds = [0, 1, 2, 6, 7, 8] targets = [3, 4, 5, 9, 10, 11] @@ -322,7 +321,8 @@ def plot_connectivity_circle(): # component (ImCoh, MIC, MIM) capture only non-zero time-lag interactions. # # The ability to capture these different interactions is not a feature specific -# to multivariate connectivity methods, as shown below for Coh and ImCoh. +# to multivariate connectivity methods, as shown below for the bivariate +# methods Coh and ImCoh. # %% @@ -361,11 +361,13 @@ def plot_connectivity_circle(): # assumed, methods based on only the imaginary part of coherency (ImCoh, MIC, # MIM) should be used.** Examples of situations include: # -# - Connectivity between channels of a single modality where volume conduction -# is expected, e.g. connectivity between EEG channels. -# - Connectivity between channels of different modalities where a common -# reference is used, e.g. connectivity between EEG and subcortical LFP using -# the same LFP reference. +# - Connectivity between channels of a single modality. +# - Connectivity between channels of different modalities where the same +# reference is used. +# +# Note that this applies not only to sensor-space signals, but also to +# source-space signals where remnants of these non-physiological interactions +# may remain even after source reconstruction. 
# # | # @@ -373,18 +375,14 @@ def plot_connectivity_circle(): # assumed, methods based on real and imaginary parts of coherency (Cohy, Coh, # CaCoh) should be used.** Examples of situations include: # -# - Connectivity between channels of a single modality where volume conduction -# is not expected, e.g. connectivity between ECoG channels. # - Connectivity between channels of different modalities where different -# references are used, e.g. connectivity between EEG and subcortical LFP -# using cortical and subcortical references, respectively. +# references are used. # # | # -# Although methods based on only the imaginary part of coherency should be used -# when non-physiological zero time-lag interactions are present, these methods -# should equally be avoided when such non-physiological interactions are -# absent. There are 2 key reasons: +# Equally, it is important to avoid methods based on only the imaginary part of +# coherency when non-physiological zero time-lag interactions are absent. There +# are two key reasons: # # **1. Discarding physiological zero time-lag interactions** # @@ -393,10 +391,10 @@ def plot_connectivity_circle(): # imaginary part of coherency may lead to information about genuine # connectivity being lost. # -# In situations where non-physiological zero time-lag -# interactions are present, the potential loss of physiological information is -# generally acceptable to avoid spurious connectivity estimates. However, -# unnecessarily discarding this information can be detrimental. +# In situations where non-physiological zero time-lag interactions are present, +# the potential loss of physiological information is generally acceptable to +# avoid spurious connectivity estimates. However, unnecessarily discarding this +# information can of course be detrimental. # # **2. 
Biasing interactions based on the angle of interaction** # @@ -466,13 +464,14 @@ def plot_connectivity_circle(): ############################################################################### # Plotting the connectivity values for CaCoh and MIC, we see how the 10-12 Hz -# and 23-25 Hz interactions only have a similar magnitude for CaCoh, whereas -# the MIC scores for the 10-12 Hz interaction are lower than for the 23-25 Hz +# and 23-25 Hz interactions have a similar magnitude for CaCoh, whereas the MIC +# scores for the 10-12 Hz interaction are lower than for the 23-25 Hz # interaction. # # This difference reflects the fact that as the angle of interaction deviates # from :math:`\pm` 90°, less information will be represented in the imaginary -# part of coherency. +# part of coherency. Accordingly, considering only the imaginary part of +# coherency can bias connectivity estimates based on the angle of interaction. # %% @@ -490,16 +489,10 @@ def plot_connectivity_circle(): fig.suptitle("CaCoh vs. MIC\n$\pm$45° & $\pm$90° interactions") ############################################################################### -# Accordingly, considering only the imaginary part of coherency can bias -# connectivity estimates based on the proximity of the phase angle of -# interactions to :math:`\pm` 90°, with closer angles leading to higher -# estimates of connectivity. -# -# Again, in situations where non-physiological zero time-lag interactions are -# present, this phase angle-dependent bias is generally acceptable to avoid -# spurious connectivity estimates. However, such a bias in situations where -# non-physiological zero time-lag interactions are not present is clearly -# problematic. +# In situations where non-physiological zero time-lag interactions are present, +# this phase angle-dependent bias is generally acceptable to avoid spurious +# connectivity estimates. 
However in situations where non-physiological zero +# time-lag interactions are not present, such a bias is clearly problematic. # # | # @@ -530,6 +523,54 @@ def plot_connectivity_circle(): axis.legend(loc="upper left") fig.suptitle("Coh vs. ImCoh\n$\pm$45° & $\pm$90° interactions") +############################################################################### +# Bivariate vs. multivariate coherency methods +# -------------------------------------------- +# +# As we have seen, coherency-based methods can be bivariate (Cohy, Coh, ImCoh) +# and multivariate (CaCoh, MIC, MIM). Whilst both forms capture the same +# information, there are several benefits to using multivariate methods when +# investigating connectivity between many signals. +# +# The multivariate methods can be used to capture the most relevant +# interactions between two groups of signals, representing this information in +# the component, rather than signal space. +# +# The dimensionality reduction associated with these methods offers: a much +# easier interpretation of the results; a higher signal-to-noise ratio compared +# to e.g. averaging bivariate connectivity estimates across multiple pairs of +# signals; and even reduced bias in what information is captured +# :footcite:`EwaldEtAl2012`. +# +# Furthermore, despite the dimensionality reduction of multivariate methods it +# is still possible to investigate the topographies of connectivity, with +# spatial patterns of connectivity being returned alongside the connectivity +# values themselves :footcite:`HaufeEtAl2014`. 
+# +# More information about the multivariate coherency-based methods can be found +# in the following examples: +# +# - CaCoh - :doc:`cacoh` +# - MIC & MIM - :doc:`mic_mim` + +############################################################################### +# Alternative approaches to computing connectivity +# ------------------------------------------------ +# +# Coherency-based methods are only some of the many approaches available in +# MNE-Connectivity for studying interactions between signals. Other +# non-directed measures include those based on the phase-lag index +# :footcite:`StamEtAl2007,VinckEtAl2011` (see also :doc:`dpli_wpli_pli`) and +# phase locking value :footcite:`LachauxEtAl1999,BrunaEtAl2018`. +# +# Furthermore, directed measures of connectivity which determine the direction +# of information flow are also available, including a variant of the phase-lag +# index :footcite:`StamEtAl2012` (see also :doc:`dpli_wpli_pli`), the phase +# slope index :footcite:`NolteEtAl2008` (see also +# :func:`mne_connectivity.phase_slope_index`), and Granger causality +# :footcite:`BarnettSeth2015,WinklerEtAl2016` (see also +# :doc:`granger_causality`). + ############################################################################### # Conclusion # ---------- @@ -538,12 +579,11 @@ def plot_connectivity_circle(): # methods are appropriate. # # Methods based on the imaginary part of coherency alone (ImCoh, MIC, MIM) -# should only be used when non-physiological zero time-lag interactions (e.g. -# volume conduction) are present. +# should be used when non-physiological zero time-lag interactions are present. # -# Methods based on the real and imaginary parts of coherency (Cohy, Coh, CaCoh) -# should only be used when non-physiological zero time-lag interactions are not -# present. +# In contrast, methods based on the real and imaginary parts of coherency +# (Cohy, Coh, CaCoh) should be used when non-physiological zero time-lag +# interactions are absent. 
############################################################################### # References diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 2d9b4bac..58a8b4f5 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -54,7 +54,8 @@ # following methods: the maximised imaginary part of coherency (MIC); and the # multivariate interaction measure (MIM). These methods are similar to the # multivariate method based on coherency (CaCoh :footcite:`VidaurreEtAl2019`; -# see :doc:`cacoh`), which is also supported by MNE-Connectivity. +# see :doc:`cacoh` and :doc:`compare_coherency_methods`), which is also +# supported by MNE-Connectivity. # # We start by loading some example MEG data and dividing it into # two-second-long epochs. @@ -173,6 +174,11 @@ # values, we can infer the existence of a dipole source between the central and # posterior regions of the left hemisphere accounting for the connectivity # contributions (represented on the plot as a green line). +# +# **N.B.** The spatial patterns are not a substitute for source reconstruction. +# If you need the spatial patterns in source space, you should perform source +# reconstruction before computing connectivity (see e.g. +# :doc:`mne_inverse_coherence_epochs`). # %% @@ -429,8 +435,8 @@ # not risk biasing connectivity estimates towards interactions with particular # phase lags like MIC/MIM. # -# These scenarios are described in more detail in the :doc:`cacoh_vs_mic` -# example. +# These scenarios are described in more detail in the +# :doc:`compare_coherency_methods` example. ############################################################################### From e095b35fb7136cb5eacd5975493c295d0a3c2009 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 19 Feb 2024 12:22:56 +0100 Subject: [PATCH 41/59] Replace N.B. 
with Note --- examples/compare_coherency_methods.py | 4 ++-- examples/mic_mim.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/compare_coherency_methods.py b/examples/compare_coherency_methods.py index 7c922341..ef1817e2 100644 --- a/examples/compare_coherency_methods.py +++ b/examples/compare_coherency_methods.py @@ -74,8 +74,8 @@ # having a method that can discard spurious zero time-lag connectivity # estimates is very desirable. # -# **N.B.** Not all zero time-lag interactions are necessarily non-physiological -# :footcite:`ViriyopaseEtAl2012`. +# **Note:** Not all zero time-lag interactions are necessarily +# non-physiological :footcite:`ViriyopaseEtAl2012`. # # To demonstrate the differences in how Cohy/Coh and ImCoh handle zero time-lag # interactions, we simulate two sets of data with: diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 58a8b4f5..09e40f70 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -175,9 +175,9 @@ # posterior regions of the left hemisphere accounting for the connectivity # contributions (represented on the plot as a green line). # -# **N.B.** The spatial patterns are not a substitute for source reconstruction. -# If you need the spatial patterns in source space, you should perform source -# reconstruction before computing connectivity (see e.g. +# **Note:** The spatial patterns are not a substitute for source +# reconstruction. If you need the spatial patterns in source space, you should +# perform source reconstruction before computing connectivity (see e.g. # :doc:`mne_inverse_coherence_epochs`). 
# %% From 8652da468637d12b5ae7f9c3ee580cd76a55b3d4 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 19 Feb 2024 12:25:04 +0100 Subject: [PATCH 42/59] Reword cohy comparison example --- examples/compare_coherency_methods.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/compare_coherency_methods.py b/examples/compare_coherency_methods.py index ef1817e2..9c8e09fe 100644 --- a/examples/compare_coherency_methods.py +++ b/examples/compare_coherency_methods.py @@ -380,8 +380,8 @@ def plot_connectivity_circle(): # # | # -# Equally, it is important to avoid methods based on only the imaginary part of -# coherency when non-physiological zero time-lag interactions are absent. There +# Equally, when there are no non-physiological zero time-lag interactions, one +# should not use methods based on only the imaginary part of coherency. There # are two key reasons: # # **1. Discarding physiological zero time-lag interactions** From 81ffefea697e61c3ec9c1bd72013bd74f83417ac Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 19 Feb 2024 18:53:42 +0100 Subject: [PATCH 43/59] Update CaCoh & MIC/MIM examples --- examples/cacoh.py | 44 +++++++++++++++++++++----------------------- examples/mic_mim.py | 6 +++--- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/examples/cacoh.py b/examples/cacoh.py index b169ff4f..04351fb3 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -55,8 +55,7 @@ # --------------- # # To demonstrate the CaCoh method, we will use some simulated data consisting -# of an interaction between signals in a given frequency range. Here, we -# simulate two sets of interactions: +# of two sets of interactions between signals in a given frequency range: # # - 5 seeds and 3 targets interacting in the 10-12 Hz frequency range. # - 5 seeds and 3 targets interacting in the 23-25 Hz frequency range.
@@ -195,6 +194,10 @@ def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarr cacoh = spectral_connectivity_epochs( data, method="cacoh", indices=multivar_indices, sfreq=100, fmin=3, fmax=35 ) +print(f"Results shape: {cacoh.get_data().shape} (connections x frequencies)") + +# Get absolute CaCoh +cacoh_abs = np.abs(cacoh.get_data())[0] ############################################################################### # As you can see below, using CaCoh we have summarised the most relevant @@ -205,11 +208,9 @@ def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarr # %% -print(f"Results shape: {cacoh.get_data().shape} (connections x frequencies)") - # Plot CaCoh fig, axis = plt.subplots(1, 1) -axis.plot(cacoh.freqs, np.abs(cacoh.get_data()[0]), linewidth=2) +axis.plot(cacoh.freqs, cacoh_abs, linewidth=2) axis.set_xlabel("Frequency (Hz)") axis.set_ylabel("Connectivity (A.U.)") fig.suptitle("CaCoh") @@ -217,7 +218,7 @@ def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarr ############################################################################### # Note that we plot the absolute values of the results (coherence) rather than # the complex values (coherency). The absolute value of connectivity will -# generally be of most interest, however information such as the phase of +# generally be of most interest. However, information such as the phase of # interaction can only be extracted from the complex-valued results, e.g. with # the :func:`numpy.angle` function. 
@@ -250,6 +251,11 @@ def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarr coh = spectral_connectivity_epochs( data, method="coh", indices=bivar_indices, sfreq=100, fmin=3, fmax=35 ) +print(f"Original results shape: {coh.get_data().shape} (connections x frequencies)") + +# Average results across connections +coh_mean = np.mean(coh.get_data(), axis=0) +print(f"Averaged results shape: {coh_mean.shape} (connections x frequencies)") ############################################################################### # Plotting the bivariate and multivariate results together, we can see that @@ -260,19 +266,10 @@ def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarr # %% -print(f"Results shape: {coh.get_data().shape} (connections x frequencies)") - -cacoh_min = np.min(np.abs(cacoh.get_data()[0])) -coh_min = np.min(np.mean(coh.get_data(), axis=0)) - # Plot CaCoh & Coh fig, axis = plt.subplots(1, 1) -axis.plot( - coh.freqs, np.abs(cacoh.get_data()[0]) - cacoh_min, linewidth=2, label="CaCoh" -) -axis.plot( - coh.freqs, np.mean(coh.get_data(), axis=0) - coh_min, linewidth=2, label="Coh" -) +axis.plot(cacoh.freqs, cacoh_abs - np.min(cacoh_abs), linewidth=2, label="CaCoh") +axis.plot(coh.freqs, coh_mean - np.min(coh_mean), linewidth=2, label="Coh") axis.set_xlabel("Frequency (Hz)") axis.set_ylabel("Baseline-corrected connectivity (A.U.)") axis.legend() @@ -460,24 +457,25 @@ def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarr # gets the singular values of the data across epochs s = np.linalg.svd(data, compute_uv=False).min(axis=0) # finds how many singular values are 'close' to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria +rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria, which is +# a hyper-parameter ############################################################################### # Limitations # ----------- # # 
Multivariate methods offer many benefits in the form of dimensionality -# reduction and signal-to-noise ratio improvements, however no method is +# reduction and signal-to-noise ratio improvements. However, no method is # perfect. When we simulated the data, we mentioned how we considered the seeds # and targets to be signals of different modalities. This is an important # factor in whether CaCoh should be used over methods based solely on the # imaginary part of coherency such as MIC and MIM. # # In short, if you want to examine connectivity between signals from the same -# modality, you should consider not using CaCoh. Instead, methods based on the -# imaginary part of coherency such as MIC and MIM should be used to avoid -# spurious connectivity estimates stemming from e.g. volume conduction -# artefacts. +# modality, you should consider using another method instead of CaCoh. Rather, +# methods based on the imaginary part of coherency such as MIC and MIM should +# be used to avoid spurious connectivity estimates stemming from e.g. volume +# conduction artefacts. # # On the other hand, if you want to examine connectivity between signals from # different modalities, CaCoh is a more appropriate method than MIC/MIM. This diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 09e40f70..511cc2a8 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -54,8 +54,7 @@ # following methods: the maximised imaginary part of coherency (MIC); and the # multivariate interaction measure (MIM). These methods are similar to the # multivariate method based on coherency (CaCoh :footcite:`VidaurreEtAl2019`; -# see :doc:`cacoh` and :doc:`compare_coherency_methods`), which is also -# supported by MNE-Connectivity. +# see :doc:`cacoh` and :doc:`compare_coherency_methods`). # # We start by loading some example MEG data and dividing it into # two-second-long epochs. 
@@ -411,7 +410,8 @@ # gets the singular values of the data s = np.linalg.svd(raw.get_data(), compute_uv=False) # finds how many singular values are 'close' to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria +rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria, which is +# a hyper-parameter ############################################################################### From e9497fff87bfca9327e3df3cea319c485aba478a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 19 Feb 2024 18:55:00 +0100 Subject: [PATCH 44/59] Update CaCoh example Co-authored-by: Adam Li --- examples/cacoh.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/cacoh.py b/examples/cacoh.py index 04351fb3..1e8d5667 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -46,15 +46,15 @@ # frequency-resolved manner :footcite:`VidaurreEtAl2019`. It is similar to # multivariate methods based on the imaginary part of coherency (MIC & MIM # :footcite:`EwaldEtAl2012`; see :doc:`mic_mim` and -# :doc:`compare_coherency_methods`), which are also supported by -# MNE-Connectivity. +# :doc:`compare_coherency_methods`). ############################################################################### # Data Simulation # --------------- # -# To demonstrate the CaCoh method, will we use some simulated data consisting +# To demonstrate the CaCoh method, we will use some simulated data consisting + # of two sets of interactions between signals in a given frequency range: # # - 5 seeds and 3 targets interacting in the 10-12 Hz frequency range. 
From 1022e83d133bf4b860d067f855d9147ccb07168c Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 19 Feb 2024 18:59:21 +0100 Subject: [PATCH 45/59] Remove empty line CaCoh example --- examples/cacoh.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/cacoh.py b/examples/cacoh.py index 1e8d5667..649e73c2 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -54,7 +54,6 @@ # --------------- # # To demonstrate the CaCoh method, we will use some simulated data consisting - # of two sets of interactions between signals in a given frequency range: # # - 5 seeds and 3 targets interacting in the 10-12 Hz frequency range. From 2d46f69f87bb4623beb28851ba86493947091235 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 19 Feb 2024 20:09:55 +0100 Subject: [PATCH 46/59] Update examples from review --- examples/cacoh.py | 11 ++++++--- examples/compare_coherency_methods.py | 3 ++- examples/mic_mim.py | 35 ++++++++++++++++++--------- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/examples/cacoh.py b/examples/cacoh.py index 649e73c2..22a255a7 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -42,10 +42,13 @@ # that are present between only a small number of channels. # # Canonical coherency (CaCoh) is a multivariate form of coherency that uses -# spatial filters to extract the relevant components of connectivity in a -# frequency-resolved manner :footcite:`VidaurreEtAl2019`. It is similar to -# multivariate methods based on the imaginary part of coherency (MIC & MIM -# :footcite:`EwaldEtAl2012`; see :doc:`mic_mim` and +# eigendecomposition-derived spatial filters to extract the underlying +# components of connectivity in a frequency-resolved manner +# :footcite:`VidaurreEtAl2019`. This approach goes beyond simply aggregating +# information across all possible combinations of signals. 
+# +# It is similar to multivariate methods based on the imaginary part of +# coherency (MIC & MIM :footcite:`EwaldEtAl2012`; see :doc:`mic_mim` and # :doc:`compare_coherency_methods`). diff --git a/examples/compare_coherency_methods.py b/examples/compare_coherency_methods.py index 9c8e09fe..a823e8d4 100644 --- a/examples/compare_coherency_methods.py +++ b/examples/compare_coherency_methods.py @@ -47,7 +47,8 @@ # between two signals), advanced multivariate measures (i.e. between groups of # signals) have also been developed based on Cohy (CaCoh # :footcite:`VidaurreEtAl2019`; can take the absolute value for a multivariate -# form of Coh) or ImCoh (MIC & MIM :footcite:`EwaldEtAl2012`). +# form of Coh; see :doc:`cacoh`) or ImCoh (MIC & MIM :footcite:`EwaldEtAl2012`; +# see :doc:`mic_mim`). # # Despite their similarities, there are distinct scenarios in which these # different methods are most appropriate, as we will show in this example. diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 511cc2a8..67c6e6d3 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -42,19 +42,32 @@ # A popular bivariate measure of connectivity is the imaginary part of # coherency, which looks at the correlation between two signals in the # frequency domain and is immune to spurious connectivity arising from volume -# conduction artefacts :footcite:`NolteEtAl2004`. However, depending on the -# degree of source mixing, this measure is susceptible to biased estimates of +# conduction artefacts :footcite:`NolteEtAl2004`. However, in cases where +# interactions between multiple signals are of interest, computing connectivity +# between all possible combinations of signals leads to a very large number of +# results which is difficult to interpret. A common approach is to average +# results across these connections, however this risks reducing the +# signal-to-noise ratio of results and burying interactions that are present +# between only a small number of channels. 
+# +# Additionally, this bivariate measure is susceptible to biased estimates of # connectivity based on the spatial proximity of sensors -# :footcite:`EwaldEtAl2012`. +# :footcite:`EwaldEtAl2012` depending on the degree of source mixing in the +# signals. +# +# To overcome this limitation, spatial filters derived from eigendecompositions +# allow connectivity to be analysed in a multivariate manner, removing the +# source mixing-dependent bias and increasing the signal-to-noise ratio of +# connectivity estimates :footcite:`EwaldEtAl2012`. This approach goes beyond +# simply aggregating information across all possible combinations of signals, +# extracting the underlying components of connectivity in a frequency-resolved +# manner. # -# To overcome this limitation, spatial filters can be used to estimate -# connectivity free from this source mixing-dependent bias, which additionally -# increases the signal-to-noise ratio and allows signals to be analysed in a -# multivariate manner :footcite:`EwaldEtAl2012`. This approach leads to the -# following methods: the maximised imaginary part of coherency (MIC); and the -# multivariate interaction measure (MIM). These methods are similar to the -# multivariate method based on coherency (CaCoh :footcite:`VidaurreEtAl2019`; -# see :doc:`cacoh` and :doc:`compare_coherency_methods`). +# This leads to the following methods: the maximised imaginary part of +# coherency (MIC); and the multivariate interaction measure (MIM). These +# methods are similar to the multivariate method based on coherency (CaCoh +# :footcite:`VidaurreEtAl2019`; see :doc:`cacoh` and +# :doc:`compare_coherency_methods`). # # We start by loading some example MEG data and dividing it into # two-second-long epochs. 
From 34af00a6a30d9e48a2ba93a9ed28725d6f4e61a2 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 20 Feb 2024 01:56:10 +0100 Subject: [PATCH 47/59] Add simulation function --- doc/api.rst | 12 +- mne_connectivity/__init__.py | 1 + mne_connectivity/datasets/__init__.py | 1 + mne_connectivity/datasets/frequency.py | 146 ++++++++++++++++++++ mne_connectivity/tests/test_datasets.py | 169 ++++++++++++++++++++++++ 5 files changed, 328 insertions(+), 1 deletion(-) create mode 100644 mne_connectivity/datasets/__init__.py create mode 100644 mne_connectivity/datasets/frequency.py create mode 100644 mne_connectivity/tests/test_datasets.py diff --git a/doc/api.rst b/doc/api.rst index 300601b4..f919f74b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -86,4 +86,14 @@ Visualization functions :toctree: generated/ plot_sensors_connectivity - plot_connectivity_circle \ No newline at end of file + plot_connectivity_circle + +Dataset functions +================= + +.. currentmodule:: mne_connectivity + +.. 
autosummary:: + :toctree: generated/ + + make_signals_in_freq_bands \ No newline at end of file diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index fc55e2d0..6570a011 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -19,6 +19,7 @@ ) from .effective import phase_slope_index from .envelope import envelope_correlation, symmetric_orth +from .datasets import make_signals_in_freq_bands from .io import read_connectivity from .spectral import spectral_connectivity_epochs, spectral_connectivity_time from .utils import ( diff --git a/mne_connectivity/datasets/__init__.py b/mne_connectivity/datasets/__init__.py new file mode 100644 index 00000000..d5c8e2eb --- /dev/null +++ b/mne_connectivity/datasets/__init__.py @@ -0,0 +1 @@ +from .frequency import make_signals_in_freq_bands diff --git a/mne_connectivity/datasets/frequency.py b/mne_connectivity/datasets/frequency.py new file mode 100644 index 00000000..7f688184 --- /dev/null +++ b/mne_connectivity/datasets/frequency.py @@ -0,0 +1,146 @@ +# Authors: Adam Li +# Thomas S. Binns +# +# License: BSD (3-clause) + +import numpy as np +from mne import EpochsArray, create_info +from mne.filter import filter_data + + +def make_signals_in_freq_bands( + n_seeds, + n_targets, + freq_band, + n_epochs=10, + n_times=200, + sfreq=100, + trans_bandwidth=1, + snr=0.7, + connection_delay=5, + tmin=0, + ch_names=None, + ch_types="eeg", + rng_seed=None, +): + """Simulate signals interacting in a given frequency band. + + Parameters + ---------- + n_seeds : int + Number of seed channels to simulate. + n_targets : int + Number of target channels to simulate. + freq_band : tuple of int or float + Frequency band where the connectivity should be simulated, where the first entry + corresponds to the lower frequency, and the second entry to the higher + frequency. + n_epochs : int (default 10) + Number of epochs in the simulated data. 
+ n_times : int (default 200) + Number of timepoints in each epoch of the simulated data. + sfreq : int | float (default 100) + Sampling frequency of the simulated data, in Hz. + trans_bandwidth : int | float (default 1) + Transition bandwidth of the filter to apply to isolate activity in + ``freq_band``, in Hz. + snr : float (default 0.7) + Signal-to-noise ratio of the simulated data in the range [0, 1]. + connection_delay : int (default 5) + Number of timepoints for the delay of connectivity between the seeds and + targets. If > 0, the target data is a delayed form of the seed data. If < 0, the + seed data is a delayed form of the target data. + tmin : int | float (default 0) + Earliest time of each epoch. + ch_names : list of str | None (default None) + Names of the channels in the simulated data. If `None`, the channels are named + according to their index and the frequency band of interaction. + ch_types : str | list of str (default "eeg") + Types of the channels in the simulated data. + rng_seed : int | None (default None) + Seed to use for the random number generator. If `None`, no seed is specified. + + Returns + ------- + epochs : mne.EpochsArray + The simulated data stored in an `mne.EpochsArray` object. The channels are + arranged according to seeds, then targets. + + Notes + ----- + Signals are simulated as a single source of activity in the given frequency band and + projected into ``n_seeds + n_targets`` noise channels. 
+ """ + n_channels = n_seeds + n_targets + + # check inputs + if n_seeds < 1 or n_targets < 1: + raise ValueError("Number of seeds and targets must each be at least 1.") + + if not isinstance(freq_band, tuple): + raise TypeError("Frequency band must be a tuple.") + if len(freq_band) != 2: + raise ValueError("Frequency band must contain two numbers.") + + if n_times < 1: + raise ValueError("Number of timepoints must be at least 1.") + + if n_epochs < 1: + raise ValueError("Number of epochs must be at least 1.") + + if sfreq <= 0: + raise ValueError("Sampling frequency must be > 0.") + + if snr < 0 or snr > 1: + raise ValueError("Signal-to-noise ratio must be between 0 and 1.") + + if np.abs(connection_delay) >= n_epochs * n_times: + raise ValueError( + "Connection delay must be less than the total number of timepoints." + ) + + # simulate data + rng = np.random.RandomState(rng_seed) + + # simulate signal source at desired frequency band + signal = rng.randn(1, n_epochs * n_times + np.abs(connection_delay)) + signal = filter_data( + data=signal, + sfreq=sfreq, + l_freq=freq_band[0], + h_freq=freq_band[1], + l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth, + fir_design="firwin2", + ) + + # simulate noise for each channel + noise = rng.randn(n_channels, n_epochs * n_times + np.abs(connection_delay)) + + # create data by projecting signal into each channel of noise + data = (signal * snr) + (noise * (1 - snr)) + + # shift data by desired delay and remove extra time + if connection_delay != 0: + if connection_delay > 0: + delay_chans = np.arange(n_seeds, n_channels) # delay targets + else: + delay_chans = np.arange(0, n_seeds) # delay seeds + data[delay_chans, np.abs(connection_delay) :] = data[ + delay_chans, : n_epochs * n_times + ] + data = data[:, : n_epochs * n_times] + + # reshape data into epochs + data = data.reshape(n_channels, n_epochs, n_times) + data = data.transpose((1, 0, 2)) # (epochs x channels x times) + + # store data in an MNE 
EpochsArray object + if ch_names is None: + ch_names = [ + f"{ch_i}_{freq_band[0]}_{freq_band[1]}" for ch_i in range(n_channels) + ] + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) + epochs = EpochsArray(data=data, info=info, tmin=tmin) + + return epochs diff --git a/mne_connectivity/tests/test_datasets.py b/mne_connectivity/tests/test_datasets.py new file mode 100644 index 00000000..1b7c2398 --- /dev/null +++ b/mne_connectivity/tests/test_datasets.py @@ -0,0 +1,169 @@ +import numpy as np +import pytest + +from mne_connectivity import ( + make_signals_in_freq_bands, + seed_target_indices, + spectral_connectivity_epochs, +) + + +@pytest.mark.parametrize("n_seeds", [1, 3]) +@pytest.mark.parametrize("n_targets", [1, 3]) +@pytest.mark.parametrize("snr", [0.7, 0.4]) +@pytest.mark.parametrize("connection_delay", [0, 3, -3]) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) +def test_make_signals_in_freq_bands(n_seeds, n_targets, snr, connection_delay, mode): + """Test `make_signals_in_freq_bands` simulates connectivity properly.""" + # Case with no spurious correlations (avoids tests randomly failing) + rng_seed = 0 + + # Simulate data + freq_band = (5, 10) # fmin, fmax (Hz) + sfreq = 100 # Hz + trans_bandwidth = 1 # Hz + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_band, + n_epochs=30, + n_times=200, + sfreq=sfreq, + trans_bandwidth=trans_bandwidth, + snr=snr, + connection_delay=connection_delay, + rng_seed=rng_seed, + ) + + # Compute connectivity + methods = ["coh", "imcoh", "dpli"] + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + fmin = 3 + fmax = sfreq // 2 + if mode == "cwt_morlet": + cwt_params = {"cwt_freqs": np.arange(fmin, fmax), "cwt_n_cycles": 3.5} + else: + cwt_params = dict() + con = spectral_connectivity_epochs( + data, + method=methods, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + 
**cwt_params, + ) + freqs = np.array(con[0].freqs) + + # Define expected connectivity values + thresh_good = dict() + thresh_bad = dict() + # Coh + thresh_good["coh"] = (0.2, 0.9) + thresh_bad["coh"] = (0.0, 0.2) + # ImCoh + if connection_delay == 0: + thresh_good["imcoh"] = (0.0, 0.2) + thresh_bad["imcoh"] = (0.0, 0.2) + else: + thresh_good["imcoh"] = (0.2, 0.8) + thresh_bad["imcoh"] = (0.0, 0.2) + # DPLI + if connection_delay == 0: + thresh_good["dpli"] = (0.3, 0.6) + thresh_bad["dpli"] = (0.3, 0.6) + elif connection_delay > 0: + thresh_good["dpli"] = (0.5, 1) + thresh_bad["dpli"] = (0.3, 0.6) + else: + thresh_good["dpli"] = (0, 0.5) + thresh_bad["dpli"] = (0.3, 0.6) + + # Check connectivity values are acceptable + freqs_good = np.argwhere( + (freqs >= freq_band[0]) & (freqs <= freq_band[1]) + ).flatten() + freqs_bad = np.argwhere( + (freqs < freq_band[0] - trans_bandwidth * 2) + | (freqs > freq_band[1] + trans_bandwidth * 2) + ).flatten() + for method_name, method_con in zip(methods, con): + con_values = method_con.get_data() + if method_name == "imcoh": + con_values = np.abs(con_values) + # freq. band of interest + con_values_good = np.mean(con_values[:, freqs_good]) + assert ( + con_values_good >= thresh_good[method_name][0] + and con_values_good <= thresh_good[method_name][1] + ) + + # other freqs. + con_values_bad = np.mean(con_values[:, freqs_bad]) + assert ( + con_values_bad >= thresh_bad[method_name][0] + and con_values_bad <= thresh_bad[method_name][1] + ) + + +def test_make_signals_error_catch(): + """Test error catching for `make_signals_in_freq_bands`.""" + freq_band = (5, 10) + + # check bad n_seeds/targets caught + with pytest.raises( + ValueError, match="Number of seeds and targets must each be at least 1." + ): + make_signals_in_freq_bands(n_seeds=0, n_targets=1, freq_band=freq_band) + with pytest.raises( + ValueError, match="Number of seeds and targets must each be at least 1." 
+ ): + make_signals_in_freq_bands(n_seeds=1, n_targets=0, freq_band=freq_band) + + # check bad freq_band caught + with pytest.raises(TypeError, match="Frequency band must be a tuple."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=1) + with pytest.raises(ValueError, match="Frequency band must contain two numbers."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=(1, 2, 3)) + + # check bad n_times + with pytest.raises(ValueError, match="Number of timepoints must be at least 1."): + make_signals_in_freq_bands( + n_seeds=1, n_targets=1, freq_band=freq_band, n_times=0 + ) + + # check bad n_epochs + with pytest.raises(ValueError, match="Number of epochs must be at least 1."): + make_signals_in_freq_bands( + n_seeds=1, n_targets=1, freq_band=freq_band, n_epochs=0 + ) + + # check bad sfreq + with pytest.raises(ValueError, match="Sampling frequency must be > 0."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, sfreq=0) + + # check bad snr + with pytest.raises( + ValueError, match="Signal-to-noise ratio must be between 0 and 1." + ): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=-1) + with pytest.raises( + ValueError, match="Signal-to-noise ratio must be between 0 and 1." 
+ ): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=2) + + # check bad connection_delay + with pytest.raises( + ValueError, + match="Connection delay must be less than the total number of timepoints.", + ): + make_signals_in_freq_bands( + n_seeds=1, + n_targets=1, + freq_band=freq_band, + n_epochs=1, + n_times=1, + connection_delay=1, + ) From 9be5e8fa91f3b087cd8cdc367c53f340ca2286d0 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 1 Mar 2024 13:23:20 +0100 Subject: [PATCH 48/59] D comes before E --- mne_connectivity/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index 6570a011..d271c706 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -17,9 +17,9 @@ SpectroTemporalConnectivity, TemporalConnectivity, ) +from .datasets import make_signals_in_freq_bands from .effective import phase_slope_index from .envelope import envelope_correlation, symmetric_orth -from .datasets import make_signals_in_freq_bands from .io import read_connectivity from .spectral import spectral_connectivity_epochs, spectral_connectivity_time from .utils import ( From 0559e9ab52a1088321c471fa0945704d350134a9 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 4 Mar 2024 19:21:09 +0100 Subject: [PATCH 49/59] Mark TODO replace simulations --- mne_connectivity/spectral/tests/test_spectral.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 96e18d6c..deed3aac 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -21,6 +21,7 @@ from mne_connectivity.spectral.epochs_bivariate import _CohEst +# TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests def create_test_dataset( sfreq, n_signals, n_epochs, n_times, tmin, tmax, fstart, 
fend, trans_bandwidth=2.0 ): @@ -114,6 +115,7 @@ def test_spectral_connectivity_parallel(method, mode, tmp_path): n_times = 256 n_jobs = 2 # test with parallelization + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data = rng.randn(n_signals, n_epochs * n_times) # simulate connectivity from 5Hz..15Hz fstart, fend = 5.0, 15.0 @@ -210,6 +212,7 @@ def test_spectral_connectivity(method, mode): # 5Hz..15Hz fstart, fend = 5.0, 15.0 + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data, times_data = create_test_dataset( sfreq, n_signals=n_signals, @@ -501,6 +504,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # 15-25 Hz connectivity fstart, fend = 15.0, 25.0 rng = np.random.RandomState(0) + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data = rng.randn(n_signals, n_epochs * n_times + delay) # simulate connectivity from fstart to fend data[n_seeds:, :] = filter_data( @@ -1183,6 +1187,7 @@ def test_spectral_connectivity_time_delayed(): # 20-30 Hz connectivity fstart, fend = 20.0, 30.0 rng = np.random.RandomState(0) + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data = rng.randn(n_signals, n_epochs * n_times + delay) # simulate connectivity from fstart to fend data[n_seeds:, :] = filter_data( @@ -1313,6 +1318,7 @@ def test_spectral_connectivity_time_resolved(method, mode): tmax = (n_times - 1) / sfreq # 5Hz..15Hz fstart, fend = 5.0, 15.0 + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data, _ = create_test_dataset( sfreq, n_signals=n_signals, @@ -1372,6 +1378,7 @@ def test_spectral_connectivity_time_padding(method, mode, padding): tmax = (n_times - 1) / sfreq # 5Hz..15Hz fstart, fend = 5.0, 15.0 + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data, _ = create_test_dataset( sfreq, n_signals=n_signals, From 
286c18d18da81f514c720248649d579a67393cff Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 4 Mar 2024 20:00:38 +0100 Subject: [PATCH 50/59] Replace trivial simulations --- .../spectral/tests/test_spectral.py | 324 ++++++++---------- 1 file changed, 135 insertions(+), 189 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index deed3aac..8c493cb3 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1,3 +1,4 @@ +import inspect import os import numpy as np @@ -9,6 +10,7 @@ from mne_connectivity import ( SpectralConnectivity, + make_signals_in_freq_bands, read_connectivity, spectral_connectivity_epochs, spectral_connectivity_time, @@ -104,35 +106,18 @@ def _stc_gen(data, sfreq, tmin, combo=False): @pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) def test_spectral_connectivity_parallel(method, mode, tmp_path): """Test saving spectral connectivity with parallel functions.""" - # Use a case known to have no spurious correlations (it would bad if - # tests could randomly fail): - rng = np.random.RandomState(0) - trans_bandwidth = 2.0 - - sfreq = 50.0 - n_signals = 3 - n_epochs = 8 - n_times = 256 n_jobs = 2 # test with parallelization - # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests - data = rng.randn(n_signals, n_epochs * n_times) - # simulate connectivity from 5Hz..15Hz - fstart, fend = 5.0, 15.0 - data[1, :] = filter_data( - data[0, :], - sfreq, - fstart, - fend, - filter_length="auto", - fir_design="firwin2", - l_trans_bandwidth=trans_bandwidth, - h_trans_bandwidth=trans_bandwidth, + data = make_signals_in_freq_bands( + n_seeds=2, + n_targets=1, + freq_band=(5, 15), + n_epochs=8, + n_times=256, + sfreq=50, + trans_bandwidth=2.0, + rng_seed=0, # case with no spurious correlations (avoid tests randomly failing) ) - # add some noise, so the spectrum is not 
exactly zero - data[1, :] += 1e-2 * rng.randn(n_times * n_epochs) - data = data.reshape(n_signals, n_epochs, n_times) - data = np.transpose(data, [1, 0, 2]) # define some frequencies for cwt cwt_freqs = np.arange(3, 24.5, 1) @@ -160,7 +145,6 @@ def test_spectral_connectivity_parallel(method, mode, tmp_path): method=method, mode=mode, indices=None, - sfreq=sfreq, mt_adaptive=adaptive, mt_low_bias=True, mt_bandwidth=mt_bandwidth, @@ -736,12 +720,17 @@ def test_multivariate_spectral_connectivity_epochs_regression(): @pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): """Test error catching for multivar. freq.-domain connectivity methods.""" - sfreq = 50.0 - n_signals = 4 # Do not change! - n_epochs = 8 - n_times = 256 - rng = np.random.RandomState(0) - data = rng.randn(n_epochs, n_signals, n_times) + sfreq = 50 # Hz + data = make_signals_in_freq_bands( + n_seeds=2, # do not change! + n_targets=2, # do not change! 
+ freq_band=(10, 20), # arbitrary for this test + n_epochs=8, + n_times=256, + sfreq=sfreq, + rng_seed=0, + ) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) cwt_freqs = np.arange(10, 25 + 1) @@ -751,12 +740,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): ): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) spectral_connectivity_epochs( - data, - method=method, - mode=mode, - indices=non_nested_indices, - sfreq=sfreq, - gc_n_lags=10, + data, method=method, mode=mode, indices=non_nested_indices, gc_n_lags=10 ) # check bad indices with repeated channels caught @@ -765,12 +749,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): ): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) spectral_connectivity_epochs( - data, - method=method, - mode=mode, - indices=repeated_indices, - sfreq=sfreq, - gc_n_lags=10, + data, method=method, mode=mode, indices=repeated_indices, gc_n_lags=10 ) # check mixed methods caught @@ -780,12 +759,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): elif isinstance(method, list): mixed_methods = [*method, "coh"] spectral_connectivity_epochs( - data, - method=mixed_methods, - mode=mode, - indices=indices, - sfreq=sfreq, - cwt_freqs=cwt_freqs, + data, method=mixed_methods, mode=mode, indices=indices, cwt_freqs=cwt_freqs ) # check bad rank args caught @@ -796,7 +770,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_low_rank, cwt_freqs=cwt_freqs, ) @@ -807,7 +780,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_high_rank, cwt_freqs=cwt_freqs, ) @@ -818,7 +790,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_few_rank, cwt_freqs=cwt_freqs, ) @@ -829,13 
+800,16 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_much_rank, cwt_freqs=cwt_freqs, ) # check rank-deficient data caught - bad_data = data.copy() + # XXX: remove logic once support for mne<1.6 is dropped + kwargs = dict() + if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: + kwargs["copy"] = False + bad_data = data.get_data(**kwargs) bad_data[:, 1] = bad_data[:, 0] bad_data[:, 3] = bad_data[:, 2] assert np.all(np.linalg.matrix_rank(bad_data[:, (0, 1), :]) == 1) @@ -876,7 +850,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, fmin=frange[0], fmax=frange[1], gc_n_lags=n_lags, @@ -886,12 +859,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): # check no indices caught with pytest.raises(ValueError, match="indices must be specified"): spectral_connectivity_epochs( - data, - method=method, - mode=mode, - indices=None, - sfreq=sfreq, - cwt_freqs=cwt_freqs, + data, method=method, mode=mode, indices=None, cwt_freqs=cwt_freqs ) # check intersecting indices caught @@ -900,12 +868,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): ValueError, match="seed and target indices must not intersect" ): spectral_connectivity_epochs( - data, - method=method, - mode=mode, - indices=bad_indices, - sfreq=sfreq, - cwt_freqs=cwt_freqs, + data, method=method, mode=mode, indices=bad_indices, cwt_freqs=cwt_freqs ) # check bad fmin/fmax caught @@ -915,7 +878,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, fmin=(10.0, 15.0), fmax=(15.0, 20.0), cwt_freqs=cwt_freqs, @@ -937,22 +899,20 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): @pytest.mark.parametrize("method", ["mic", "mim", "gc", "gc_tr"]) def 
test_multivar_spectral_connectivity_parallel(method): """Test multivar. freq.-domain connectivity methods run in parallel.""" - sfreq = 50.0 - n_signals = 4 # Do not change! - n_epochs = 8 - n_times = 256 - rng = np.random.RandomState(0) - data = rng.randn(n_epochs, n_signals, n_times) + data = make_signals_in_freq_bands( + n_seeds=2, # do not change! + n_targets=2, # do not change! + freq_band=(10, 20), # arbitrary for this test + n_epochs=8, + n_times=256, + sfreq=50, + rng_seed=0, + ) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) spectral_connectivity_epochs( - data, - method=method, - mode="multitaper", - indices=indices, - sfreq=sfreq, - gc_n_lags=10, - n_jobs=2, + data, method=method, mode="multitaper", indices=indices, gc_n_lags=10, n_jobs=2 ) spectral_connectivity_time( data, @@ -960,7 +920,6 @@ def test_multivar_spectral_connectivity_parallel(method): method=method, mode="multitaper", indices=indices, - sfreq=sfreq, gc_n_lags=10, n_jobs=2, ) @@ -968,12 +927,16 @@ def test_multivar_spectral_connectivity_parallel(method): def test_multivar_spectral_connectivity_flipped_indices(): """Test multivar. indices structure maintained by connectivity methods.""" - sfreq = 50.0 - n_signals = 4 - n_epochs = 8 - n_times = 256 - rng = np.random.RandomState(0) - data = rng.randn(n_epochs, n_signals, n_times) + data = make_signals_in_freq_bands( + n_seeds=2, # do not change! + n_targets=2, # do not change! 
+ freq_band=(10, 20), # arbitrary for this test + n_epochs=8, + n_times=256, + sfreq=50, + rng_seed=0, + ) + freqs = np.arange(10, 20) # if we're not careful, when finding the channels we need to compute the @@ -986,26 +949,26 @@ def test_multivar_spectral_connectivity_flipped_indices(): method = "gc" con_st = spectral_connectivity_epochs( # seed -> target - data, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10 + data, method=method, indices=indices, gc_n_lags=10 ) con_ts = spectral_connectivity_epochs( # target -> seed - data, method=method, indices=flipped_indices, sfreq=sfreq, gc_n_lags=10 + data, method=method, indices=flipped_indices, gc_n_lags=10 ) con_st_ts = spectral_connectivity_epochs( # seed -> target; target -> seed - data, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10 + data, method=method, indices=concat_indices, gc_n_lags=10 ) assert not np.all(con_st.get_data() == con_ts.get_data()) assert np.all(con_st.get_data()[0] == con_st_ts.get_data()[0]) assert np.all(con_ts.get_data()[0] == con_st_ts.get_data()[1]) con_st = spectral_connectivity_time( # seed -> target - data, freqs, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10 + data, freqs, method=method, indices=indices, gc_n_lags=10 ) con_ts = spectral_connectivity_time( # target -> seed - data, freqs, method=method, indices=flipped_indices, sfreq=sfreq, gc_n_lags=10 + data, freqs, method=method, indices=flipped_indices, gc_n_lags=10 ) con_st_ts = spectral_connectivity_time( # seed -> target; target -> seed - data, freqs, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10 + data, freqs, method=method, indices=concat_indices, gc_n_lags=10 ) assert not np.all(con_st.get_data() == con_ts.get_data()) assert np.all(con_st.get_data()[:, 0] == con_st_ts.get_data()[:, 0]) @@ -1451,12 +1414,17 @@ def test_spectral_connectivity_time_padding(method, mode, padding): @pytest.mark.parametrize("faverage", [True, False]) def 
test_multivar_spectral_connectivity_time_shapes(method, average, faverage): """Test result shapes of time-resolved multivar. connectivity methods.""" - sfreq = 50.0 - n_signals = 4 # Do not change! n_epochs = 8 - n_times = 500 - rng = np.random.RandomState(0) - data = rng.randn(n_epochs, n_signals, n_times) + data = make_signals_in_freq_bands( + n_seeds=2, # do not change! + n_targets=2, # do not change! + freq_band=(10, 20), # arbitrary for this test + n_epochs=n_epochs, + n_times=256, + sfreq=50, + rng_seed=0, + ) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) n_cons = len(indices[0]) freqs = np.arange(10, 25 + 1) @@ -1475,7 +1443,6 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): freqs, indices=indices, method=method, - sfreq=sfreq, faverage=faverage, average=average, gc_n_lags=10, @@ -1504,7 +1471,6 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): freqs, indices=indices, method=method, - sfreq=sfreq, faverage=faverage, average=average, gc_n_lags=10, @@ -1527,11 +1493,19 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): @pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) def test_multivar_spectral_connectivity_time_error_catch(method, mode): """Test error catching for time-resolved multivar. connectivity methods.""" - sfreq = 50.0 - n_signals = 4 # Do not change! - n_epochs = 8 - n_times = 256 - data = np.random.rand(n_epochs, n_signals, n_times) + n_seeds = 2 # do not change! + n_targets = 2 # do not change! 
+ n_signals = n_seeds + n_targets + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=(10, 20), # arbitrary for this test + n_epochs=8, + n_times=256, + sfreq=50, + rng_seed=0, + ) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) freqs = np.arange(10, 25 + 1) @@ -1545,12 +1519,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) spectral_connectivity_time( - data, - freqs, - method=method, - mode=mode, - indices=non_nested_indices, - sfreq=sfreq, + data, freqs, method=method, mode=mode, indices=non_nested_indices ) # check bad indices with repeated channels caught @@ -1559,66 +1528,42 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) spectral_connectivity_time( - data, freqs, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq + data, freqs, method=method, mode=mode, indices=repeated_indices ) # check mixed methods caught with pytest.raises(ValueError, match="bivariate and multivariate connectivity"): mixed_methods = [method, "coh"] spectral_connectivity_time( - data, freqs, method=mixed_methods, mode=mode, indices=indices, sfreq=sfreq + data, freqs, method=mixed_methods, mode=mode, indices=indices ) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_time( - data, - freqs, - method=method, - indices=indices, - sfreq=sfreq, - mode=mode, - rank=too_low_rank, + data, freqs, method=method, indices=indices, mode=mode, rank=too_low_rank ) too_high_rank = (np.array([3]), np.array([3])) with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_time( - data, - freqs, - method=method, - indices=indices, - sfreq=sfreq, - mode=mode, - rank=too_high_rank, + data, freqs, method=method, 
indices=indices, mode=mode, rank=too_high_rank ) too_few_rank = ([], []) with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_time( - data, - freqs, - method=method, - indices=indices, - sfreq=sfreq, - mode=mode, - rank=too_few_rank, + data, freqs, method=method, indices=indices, mode=mode, rank=too_few_rank ) too_much_rank = (np.array([2, 2]), np.array([2, 2])) with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_time( - data, - freqs, - method=method, - indices=indices, - sfreq=sfreq, - mode=mode, - rank=too_much_rank, + data, freqs, method=method, indices=indices, mode=mode, rank=too_much_rank ) # check all-to-all conn. computed for MIC/MIM when no indices given if method in ["mic", "mim"]: con = spectral_connectivity_time( - data, freqs, method=method, indices=None, sfreq=sfreq, mode=mode + data, freqs, method=method, indices=None, mode=mode ) assert con.indices is None assert con.n_nodes == n_signals @@ -1629,7 +1574,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): # check no indices caught with pytest.raises(ValueError, match="indices must be specified"): spectral_connectivity_time( - data, freqs, method=method, mode=mode, indices=None, sfreq=sfreq + data, freqs, method=method, mode=mode, indices=None ) # check intersecting indices caught @@ -1638,7 +1583,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ValueError, match="seed and target indices must not intersect" ): spectral_connectivity_time( - data, freqs, method=method, mode=mode, indices=bad_indices, sfreq=sfreq + data, freqs, method=method, mode=mode, indices=bad_indices ) # check bad fmin/fmax caught @@ -1649,7 +1594,6 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, fmin=(5.0, 15.0), fmax=(15.0, 30.0), ) @@ -1657,14 +1601,15 @@ def 
test_multivar_spectral_connectivity_time_error_catch(method, mode): def test_save(tmp_path): """Test saving results of spectral connectivity.""" - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 10, 3, 2000, 1000.0, 20.0 - data = rng.randn(n_epochs, n_chs, n_times) - sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) - data[:, :, 500:1500] += sig - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) + epochs = make_signals_in_freq_bands( + n_seeds=2, + n_targets=1, + freq_band=(18, 22), # arbitrary for this test + n_epochs=10, + n_times=2000, + sfreq=1000, + rng_seed=0, + ) conn = spectral_connectivity_epochs( epochs, fmin=(4, 8, 13, 30), fmax=(8, 13, 30, 45), faverage=True @@ -1674,14 +1619,16 @@ def test_save(tmp_path): def test_multivar_save_load(tmp_path): """Test saving and loading results of multivariate connectivity.""" - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000.0, 20.0 - data = rng.randn(n_epochs, n_chs, n_times) - sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) - data[:, :, 500:1500] += sig - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) + epochs = make_signals_in_freq_bands( + n_seeds=2, + n_targets=2, + freq_band=(18, 22), # arbitrary for this test + n_epochs=5, + n_times=2000, + sfreq=1000, + rng_seed=0, + ) + tmp_file = os.path.join(tmp_path, "foo_mvc.nc") non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]])) @@ -1691,7 +1638,6 @@ def test_multivar_save_load(tmp_path): epochs, method=["mic", "mim", "gc", "gc_tr"], indices=indices, - sfreq=sfreq, fmin=10, fmax=30, ) @@ -1719,22 +1665,24 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): be None, otherwise, `indices` should be a tuple. The type of `indices` and its values should be retained after saving and reloading. 
""" - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 - data = rng.randn(n_epochs, n_chs, n_times) - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) + epochs = make_signals_in_freq_bands( + n_seeds=2, + n_targets=2, + freq_band=(18, 22), # arbitrary for this test + n_epochs=10, + n_times=200, + sfreq=100, + rng_seed=0, + ) + freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") # test the pair of method and indices defined to check the output indices con_epochs = spectral_connectivity_epochs( - epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 - ) - con_time = spectral_connectivity_time( - epochs, freqs, method=method, indices=indices, sfreq=sfreq + epochs, method=method, indices=indices, fmin=10, fmax=30 ) + con_time = spectral_connectivity_time(epochs, freqs, method=method, indices=indices) for con in [con_epochs, con_time]: con.save(tmp_file) @@ -1760,12 +1708,16 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io(tmp_path, method, i be None, otherwise, `indices` should be a tuple. The type of `indices` and its values should be retained after saving and reloading. 
""" - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 - data = rng.randn(n_epochs, n_chs, n_times) - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) + epochs = make_signals_in_freq_bands( + n_seeds=2, + n_targets=2, + freq_band=(18, 22), # arbitrary for this test + n_epochs=10, + n_times=200, + sfreq=100, + rng_seed=0, + ) + freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") @@ -1775,16 +1727,10 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io(tmp_path, method, i pytest.skip() con_epochs = spectral_connectivity_epochs( - epochs, - method=method, - indices=indices, - sfreq=sfreq, - fmin=10, - fmax=30, - gc_n_lags=10, + epochs, method=method, indices=indices, fmin=10, fmax=30, gc_n_lags=10 ) con_time = spectral_connectivity_time( - epochs, freqs, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10 + epochs, freqs, method=method, indices=indices, gc_n_lags=10 ) for con in [con_epochs, con_time]: From fb2e5804724f32137085b06d2e70f38e1566c357 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 4 Mar 2024 20:15:09 +0100 Subject: [PATCH 51/59] Apply suggestions from code review Co-authored-by: Adam Li --- mne_connectivity/datasets/frequency.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mne_connectivity/datasets/frequency.py b/mne_connectivity/datasets/frequency.py index 7f688184..91f36750 100644 --- a/mne_connectivity/datasets/frequency.py +++ b/mne_connectivity/datasets/frequency.py @@ -43,7 +43,8 @@ def make_signals_in_freq_bands( Sampling frequency of the simulated data, in Hz. trans_bandwidth : int | float (default 1) Transition bandwidth of the filter to apply to isolate activity in - ``freq_band``, in Hz. + ``freq_band``, in Hz. These are passed to the ``l_bandwidth`` and ``h_bandwidth`` + keyword arguments in :func:`mne.filter.create_filter`. 
snr : float (default 0.7) Signal-to-noise ratio of the simulated data in the range [0, 1]. connection_delay : int (default 5) @@ -54,7 +55,8 @@ def make_signals_in_freq_bands( Earliest time of each epoch. ch_names : list of str | None (default None) Names of the channels in the simulated data. If `None`, the channels are named - according to their index and the frequency band of interaction. + according to their index and the frequency band of interaction. If specified, must be a list of + ``n_seeds + n_targets`` channel names. ch_types : str | list of str (default "eeg") Types of the channels in the simulated data. rng_seed : int | None (default None) @@ -62,7 +64,7 @@ def make_signals_in_freq_bands( Returns ------- - epochs : mne.EpochsArray + epochs : mne.EpochsArray of shape (n_epochs, n_seeds + n_targets, n_times) The simulated data stored in an `mne.EpochsArray` object. The channels are arranged according to seeds, then targets. @@ -100,10 +102,10 @@ def make_signals_in_freq_bands( ) # simulate data - rng = np.random.RandomState(rng_seed) + rng = np.random.default_rng(rng_seed) # simulate signal source at desired frequency band - signal = rng.randn(1, n_epochs * n_times + np.abs(connection_delay)) + signal = rng.standard_normal(size=(1, n_epochs * n_times + np.abs(connection_delay))) signal = filter_data( data=signal, sfreq=sfreq, @@ -115,7 +117,7 @@ def make_signals_in_freq_bands( ) # simulate noise for each channel - noise = rng.randn(n_channels, n_epochs * n_times + np.abs(connection_delay)) + noise = rng.standard_normal(size=(n_channels, n_epochs * n_times + np.abs(connection_delay))) # create data by projecting signal into each channel of noise data = (signal * snr) + (noise * (1 - snr)) From 358004a9c61d37eefc3bc01e0abbd40e422fa0f7 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 4 Mar 2024 20:16:45 +0100 Subject: [PATCH 52/59] Add suggestion from code review --- mne_connectivity/datasets/frequency.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/mne_connectivity/datasets/frequency.py b/mne_connectivity/datasets/frequency.py index 7f688184..6cc73758 100644 --- a/mne_connectivity/datasets/frequency.py +++ b/mne_connectivity/datasets/frequency.py @@ -56,7 +56,8 @@ def make_signals_in_freq_bands( Names of the channels in the simulated data. If `None`, the channels are named according to their index and the frequency band of interaction. ch_types : str | list of str (default "eeg") - Types of the channels in the simulated data. + Types of the channels in the simulated data. If specified as a list, must be a + list of ``n_seeds + n_targets`` channel names. rng_seed : int | None (default None) Seed to use for the random number generator. If `None`, no seed is specified. From 5509a28d7b6ebd5a81d2d514fdc1da0d1b84653b Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 4 Mar 2024 20:18:45 +0100 Subject: [PATCH 53/59] Fix black formatting --- mne_connectivity/datasets/frequency.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/datasets/frequency.py b/mne_connectivity/datasets/frequency.py index 62875ecd..153bafb7 100644 --- a/mne_connectivity/datasets/frequency.py +++ b/mne_connectivity/datasets/frequency.py @@ -106,7 +106,9 @@ def make_signals_in_freq_bands( rng = np.random.default_rng(rng_seed) # simulate signal source at desired frequency band - signal = rng.standard_normal(size=(1, n_epochs * n_times + np.abs(connection_delay))) + signal = rng.standard_normal( + size=(1, n_epochs * n_times + np.abs(connection_delay)) + ) signal = filter_data( data=signal, sfreq=sfreq, @@ -118,7 +120,9 @@ def make_signals_in_freq_bands( ) # simulate noise for each channel - noise = rng.standard_normal(size=(n_channels, n_epochs * n_times + np.abs(connection_delay))) + noise = rng.standard_normal( + size=(n_channels, n_epochs * n_times + np.abs(connection_delay)) + ) # create data by projecting signal into each channel of noise data = 
(signal * snr) + (noise * (1 - snr)) From 46d40cdaa620dd5d9d1791682ed44a4cc38b0e3f Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 4 Mar 2024 20:28:57 +0100 Subject: [PATCH 54/59] Fix thresholds for new RNG settings --- mne_connectivity/tests/test_datasets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mne_connectivity/tests/test_datasets.py b/mne_connectivity/tests/test_datasets.py index 1b7c2398..4ae7d5ac 100644 --- a/mne_connectivity/tests/test_datasets.py +++ b/mne_connectivity/tests/test_datasets.py @@ -65,11 +65,11 @@ def test_make_signals_in_freq_bands(n_seeds, n_targets, snr, connection_delay, m thresh_bad["coh"] = (0.0, 0.2) # ImCoh if connection_delay == 0: - thresh_good["imcoh"] = (0.0, 0.2) - thresh_bad["imcoh"] = (0.0, 0.2) + thresh_good["imcoh"] = (0.0, 0.17) + thresh_bad["imcoh"] = (0.0, 0.17) else: - thresh_good["imcoh"] = (0.2, 0.8) - thresh_bad["imcoh"] = (0.0, 0.2) + thresh_good["imcoh"] = (0.17, 0.8) + thresh_bad["imcoh"] = (0.0, 0.17) # DPLI if connection_delay == 0: thresh_good["dpli"] = (0.3, 0.6) From e51d1303619fcd2c760a0616239af30e64e0f283 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 4 Mar 2024 20:42:35 +0100 Subject: [PATCH 55/59] Fix missing object references --- mne_connectivity/datasets/frequency.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mne_connectivity/datasets/frequency.py b/mne_connectivity/datasets/frequency.py index 153bafb7..8fd29308 100644 --- a/mne_connectivity/datasets/frequency.py +++ b/mne_connectivity/datasets/frequency.py @@ -43,8 +43,8 @@ def make_signals_in_freq_bands( Sampling frequency of the simulated data, in Hz. trans_bandwidth : int | float (default 1) Transition bandwidth of the filter to apply to isolate activity in - ``freq_band``, in Hz. These are passed to the ``l_bandwidth`` and ``h_bandwidth`` - keyword arguments in :func:`mne.filter.create_filter`. + ``freq_band``, in Hz. 
These are passed to the ``l_bandwidth`` and + ``h_bandwidth`` keyword arguments in :func:`mne.filter.create_filter`. snr : float (default 0.7) Signal-to-noise ratio of the simulated data in the range [0, 1]. connection_delay : int (default 5) @@ -55,8 +55,8 @@ def make_signals_in_freq_bands( Earliest time of each epoch. ch_names : list of str | None (default None) Names of the channels in the simulated data. If `None`, the channels are named - according to their index and the frequency band of interaction. If specified, must be a list of - ``n_seeds + n_targets`` channel names. + according to their index and the frequency band of interaction. If specified, + must be a list of ``n_seeds + n_targets`` channel names. ch_types : str | list of str (default "eeg") Types of the channels in the simulated data. If specified as a list, must be a list of ``n_seeds + n_targets`` channel names. @@ -65,7 +65,7 @@ def make_signals_in_freq_bands( Returns ------- - epochs : mne.EpochsArray of shape (n_epochs, n_seeds + n_targets, n_times) + epochs : mne.EpochsArray of shape (n_epochs, ``n_seeds + n_targets``, n_times) The simulated data stored in an `mne.EpochsArray` object. The channels are arranged according to seeds, then targets. 
From f87edae9c5b2697d14ed80a2fb038ce40e06be39 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 12 Mar 2024 17:02:20 +0100 Subject: [PATCH 56/59] Refactor multivariate methods for connectivity classes --- mne_connectivity/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/base.py b/mne_connectivity/base.py index fdb03884..89f30bca 100644 --- a/mne_connectivity/base.py +++ b/mne_connectivity/base.py @@ -17,6 +17,7 @@ warn, ) +from mne_connectivity.spectral.epochs_multivariate import _multivariate_methods from mne_connectivity.utils import _prepare_xarray_mne_data_structures, fill_doc from mne_connectivity.viz import plot_connectivity_circle @@ -725,7 +726,7 @@ def get_data(self, output="compact"): if output == "raveled": data = self._data else: - if self.method in ["mic", "mim", "gc", "gc_tr"]: + if self.method in _multivariate_methods: # multivariate results cannot be returned in a dense form as a # single set of results would correspond to multiple entries in # the matrix, and there could also be cases where multiple From 00568203b863055bd65aad92d663673aef3ae642 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 12 Mar 2024 17:06:46 +0100 Subject: [PATCH 57/59] Revert "Refactor multivariate methods for connectivity classes" This reverts commit f87edae9c5b2697d14ed80a2fb038ce40e06be39. 
--- mne_connectivity/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mne_connectivity/base.py b/mne_connectivity/base.py index 89f30bca..fdb03884 100644 --- a/mne_connectivity/base.py +++ b/mne_connectivity/base.py @@ -17,7 +17,6 @@ warn, ) -from mne_connectivity.spectral.epochs_multivariate import _multivariate_methods from mne_connectivity.utils import _prepare_xarray_mne_data_structures, fill_doc from mne_connectivity.viz import plot_connectivity_circle @@ -726,7 +725,7 @@ def get_data(self, output="compact"): if output == "raveled": data = self._data else: - if self.method in _multivariate_methods: + if self.method in ["mic", "mim", "gc", "gc_tr"]: # multivariate results cannot be returned in a dense form as a # single set of results would correspond to multiple entries in # the matrix, and there could also be cases where multiple From 2d756fbbeb0612565fb1fb50fe1cb5fc5e467224 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 12 Mar 2024 17:07:15 +0100 Subject: [PATCH 58/59] Add CaCoh as multivariate method to classes --- mne_connectivity/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/base.py b/mne_connectivity/base.py index fdb03884..e38fd2c2 100644 --- a/mne_connectivity/base.py +++ b/mne_connectivity/base.py @@ -725,7 +725,7 @@ def get_data(self, output="compact"): if output == "raveled": data = self._data else: - if self.method in ["mic", "mim", "gc", "gc_tr"]: + if self.method in ["cacoh", "mic", "mim", "gc", "gc_tr"]: # multivariate results cannot be returned in a dense form as a # single set of results would correspond to multiple entries in # the matrix, and there could also be cases where multiple From 95cd343978f1c0f03dd1dbd8df876fe886519ca0 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 18 Mar 2024 23:24:56 +0100 Subject: [PATCH 59/59] Update examples to use simulation function --- examples/cacoh.py | 99 ++++------------------ 
examples/compare_coherency_methods.py | 114 ++++++-------------------- 2 files changed, 43 insertions(+), 170 deletions(-) diff --git a/examples/cacoh.py b/examples/cacoh.py index 22a255a7..9f3d2838 100644 --- a/examples/cacoh.py +++ b/examples/cacoh.py @@ -9,16 +9,19 @@ spatial patterns of the connectivity. """ -# Authors: Mohammad Orabe -# Thomas S. Binns +# Authors: Thomas S. Binns +# Mohammad Orabe # License: BSD (3-clause) # %% import numpy as np from matplotlib import pyplot as plt -import mne -from mne_connectivity import seed_target_indices, spectral_connectivity_epochs +from mne_connectivity import ( + make_signals_in_freq_bands, + seed_target_indices, + spectral_connectivity_epochs, +) ############################################################################### # Background @@ -64,95 +67,29 @@ # # We can consider the seeds and targets to be signals of different modalities, # e.g. cortical EEG signals and subcortical LFP signals, cortical EEG signals -# and muscular EMG signals, etc.... We use the function below to simulate these -# signals. - -# %% - - -def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarray: - """Simulates signals interacting in a given frequency band. - - Parameters - ---------- - freq_band : tuple of int, int - Frequency band where the connectivity should be simulated, where the - first entry corresponds to the lower frequency, and the second entry to - the higher frequency. - - rng_seed : int - Seed to use for the random number generator. - - Returns - ------- - data : numpy.ndarray - The simulated data stored in an array. The channels are arranged - according to seeds, then targets. 
- """ - # Define fixed simulation parameters - n_seeds = 5 - n_targets = 3 - n_epochs = 10 - n_times = 200 # samples - sfreq = 100 # Hz - snr = 0.7 - trans_bandwidth = 1 # Hz - connection_delay = 1 # sample - - np.random.seed(rng_seed) - - n_channels = n_seeds + n_targets - - # simulate signal source at desired frequency band - signal = np.random.randn(1, n_epochs * n_times + connection_delay) - signal = mne.filter.filter_data( - data=signal, - sfreq=sfreq, - l_freq=freq_band[0], - h_freq=freq_band[1], - l_trans_bandwidth=trans_bandwidth, - h_trans_bandwidth=trans_bandwidth, - fir_design="firwin2", - verbose=False, - ) - - # simulate noise for each channel - noise = np.random.randn(n_channels, n_epochs * n_times + connection_delay) - - # create data by projecting signal into noise - data = (signal * snr) + (noise * (1 - snr)) - - # shift target data by desired delay - if connection_delay > 0: - # shift target data - data[n_seeds:, connection_delay:] = data[n_seeds:, : n_epochs * n_times] - # remove extra time - data = data[:, : n_epochs * n_times] - - # reshape data into epochs - data = data.reshape(n_channels, n_epochs, n_times) - data = data.transpose((1, 0, 2)) # (epochs x channels x times) - - return data - - -############################################################################### +# and muscular EMG signals, etc.... We use the +# :func:`~mne_connectivity.make_signals_in_freq_bands` function to simulate +# these signals. 
# %% # Generate simulated data -data_10_12 = simulate_connectivity( +data_10_12 = make_signals_in_freq_bands( + n_seeds=5, + n_targets=3, freq_band=(10, 12), # 10-12 Hz interaction rng_seed=42, ) -data_23_25 = simulate_connectivity( +data_23_25 = make_signals_in_freq_bands( + n_seeds=5, + n_targets=3, freq_band=(23, 25), # 23-25 Hz interaction rng_seed=44, ) -# Combine data into a single array -data = np.concatenate((data_10_12, data_23_25), axis=1) +# Combine data into a single object +data = data_10_12.add_channels([data_23_25]) ############################################################################### # Computing CaCoh diff --git a/examples/compare_coherency_methods.py b/examples/compare_coherency_methods.py index a823e8d4..70848b6b 100644 --- a/examples/compare_coherency_methods.py +++ b/examples/compare_coherency_methods.py @@ -16,8 +16,11 @@ import numpy as np from matplotlib import pyplot as plt -import mne -from mne_connectivity import seed_target_indices, spectral_connectivity_epochs +from mne_connectivity import ( + make_signals_in_freq_bands, + seed_target_indices, + spectral_connectivity_epochs, +) ############################################################################### # An introduction to coherency-based connectivity methods @@ -86,96 +89,25 @@ # %% - -def simulate_connectivity( - freq_band: tuple[int, int], connection_delay: int, rng_seed: int -) -> np.ndarray: - """Simulates signals interacting in a given frequency band. - - Parameters - ---------- - freq_band : tuple of int, int - Frequency band where the connectivity should be simulated, where the - first entry corresponds to the lower frequency, and the second entry to - the higher frequency. - - connection_delay : - Number of timepoints for the delay of connectivity between the seeds - and targets. If > 0, the target data is a delayed form of the seed data - by this many timepoints. - - rng_seed : int - Seed to use for the random number generator. 
- - Returns - ------- - data : numpy.ndarray - The simulated data stored in an array. The channels are arranged - according to seeds, then targets. - """ - # Define fixed simulation parameters - n_seeds = 3 - n_targets = 3 - n_epochs = 10 - n_times = 200 # samples - sfreq = 100 # Hz - snr = 0.7 - trans_bandwidth = 1 # Hz - - np.random.seed(rng_seed) - - n_channels = n_seeds + n_targets - - # simulate signal source at desired frequency band - signal = np.random.randn(1, n_epochs * n_times + connection_delay) - signal = mne.filter.filter_data( - data=signal, - sfreq=sfreq, - l_freq=freq_band[0], - h_freq=freq_band[1], - l_trans_bandwidth=trans_bandwidth, - h_trans_bandwidth=trans_bandwidth, - fir_design="firwin2", - verbose=False, - ) - - # simulate noise for each channel - noise = np.random.randn(n_channels, n_epochs * n_times + connection_delay) - - # create data by projecting signal into noise - data = (signal * snr) + (noise * (1 - snr)) - - # shift target data by desired delay - if connection_delay > 0: - # shift target data - data[n_seeds:, connection_delay:] = data[n_seeds:, : n_epochs * n_times] - # remove extra time - data = data[:, : n_epochs * n_times] - - # reshape data into epochs - data = data.reshape(n_channels, n_epochs, n_times) - data = data.transpose((1, 0, 2)) # (epochs x channels x times) - - return data - - -# %% - # Generate simulated data -data_delay = simulate_connectivity( +data_delay = make_signals_in_freq_bands( + n_seeds=3, + n_targets=3, freq_band=(10, 12), # 10-12 Hz interaction connection_delay=2, # samples; non-zero time-lag rng_seed=42, ) -data_no_delay = simulate_connectivity( +data_no_delay = make_signals_in_freq_bands( + n_seeds=3, + n_targets=3, freq_band=(23, 25), # 23-25 Hz interaction connection_delay=0, # samples; zero time-lag rng_seed=44, ) -# Combine data into a single array -data = np.concatenate((data_delay, data_no_delay), axis=1) +# Combine data into a single object +data = data_delay.add_channels([data_no_delay]) 
############################################################################### # We compute the connectivity of these simulated signals using CaCoh (a @@ -374,7 +306,7 @@ def plot_connectivity_circle(): # # **In situations where non-physiological zero time-lag interactions are not # assumed, methods based on real and imaginary parts of coherency (Cohy, Coh, -# CaCoh) should be used.** Examples of situations include: +# CaCoh) should be used.** An example includes: # # - Connectivity between channels of different modalities where different # references are used. @@ -411,20 +343,24 @@ def plot_connectivity_circle(): # %% # Generate simulated data -data_10_12 = simulate_connectivity( +data_10_12 = make_signals_in_freq_bands( + n_seeds=3, + n_targets=3, freq_band=(10, 12), # 10-12 Hz interaction connection_delay=1, # samples - rng_seed=42, + rng_seed=40, ) -data_23_25 = simulate_connectivity( - freq_band=(23, 25), # 10-12 Hz interaction +data_23_25 = make_signals_in_freq_bands( + n_seeds=3, + n_targets=3, + freq_band=(23, 25), # 23-25 Hz interaction connection_delay=1, # samples - rng_seed=44, + rng_seed=42, ) # Combine data into a single array -data = np.concatenate((data_10_12, data_23_25), axis=1) +data = data_10_12.add_channels([data_23_25]) # Compute CaCoh & MIC (cacoh, mic) = spectral_connectivity_epochs( @@ -519,7 +455,7 @@ def plot_connectivity_circle(): axis.plot(imcoh.freqs, imcoh_mean_subbed, linewidth=2, label="ImCoh", linestyle="--") axis.set_xlabel("Frequency (Hz)") axis.set_ylabel("Mean-corrected connectivity (A.U.)") -axis.annotate("$\pm$45°\ninteraction", xy=(12, 0.25)) +axis.annotate("$\pm$45°\ninteraction", xy=(13, 0.25)) axis.annotate("$\pm$90°\ninteraction", xy=(26.5, 0.25)) axis.legend(loc="upper left") fig.suptitle("Coh vs. ImCoh\n$\pm$45° & $\pm$90° interactions")