From eda32b5996e6ed9ced88009b069485f95ac11496 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 24 Jul 2023 11:19:05 +0200 Subject: [PATCH 01/40] added support for ragged connections --- doc/api.rst | 2 + examples/granger_causality.py | 32 +- examples/handling_ragged_arrays.py | 154 ++++++++++ examples/mic_mim.py | 43 ++- mne_connectivity/__init__.py | 3 +- mne_connectivity/base.py | 9 +- mne_connectivity/spectral/epochs.py | 273 ++++++++++-------- .../spectral/tests/test_spectral.py | 167 +++++++---- mne_connectivity/spectral/time.py | 167 ++++++----- mne_connectivity/tests/test_utils.py | 73 ++++- mne_connectivity/utils/__init__.py | 3 +- mne_connectivity/utils/utils.py | 95 ++++++ 12 files changed, 726 insertions(+), 295 deletions(-) create mode 100644 examples/handling_ragged_arrays.py diff --git a/doc/api.rst b/doc/api.rst index 26ef14b6..c91f9c02 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -73,7 +73,9 @@ Post-processing on connectivity degree seed_target_indices + seed_target_multivariate_indices check_indices + check_multivariate_indices select_order Visualization functions diff --git a/examples/granger_causality.py b/examples/granger_causality.py index f5d8316d..64a657db 100644 --- a/examples/granger_causality.py +++ b/examples/granger_causality.py @@ -145,7 +145,7 @@ ############################################################################### # We will focus on connectivity between sensors over the parietal and occipital -# cortices, with 20 parietal sensors designated as group A, and 20 occipital +# cortices, with 20 parietal sensors designated as group A, and 22 occipital # sensors designated as group B. # %% @@ -157,17 +157,8 @@ signals_b = [idx for idx, ch_info in enumerate(epochs.info['chs']) if ch_info['ch_name'][2] == 'O'] -# XXX: Currently ragged indices are not supported, so we only consider a single -# list of indices with an equal number of seeds and targets -min_n_chs = min(len(signals_a), len(signals_b)) -signals_a = signals_a[:min_n_chs] -signals_b = signals_b[:min_n_chs] - -indices_ab = (np.array(signals_a), np.array(signals_b)) # A => B -indices_ba = (np.array(signals_b), np.array(signals_a)) # B => A - -signals_a_names = [epochs.info['ch_names'][idx] for idx in signals_a] -signals_b_names = [epochs.info['ch_names'][idx] for idx in signals_b] +indices_ab = (np.array([signals_a]), np.array([signals_b])) # A => B +indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A # compute Granger causality gc_ab = spectral_connectivity_epochs( @@ -181,8 +172,8 @@ ############################################################################### # Plotting the results, we see that there is a flow of information from our -# parietal sensors (group A) to our occipital sensors (group B) with noticeable -# peaks at around 8, 18, and 26 Hz. +# parietal sensors (group A) to our occipital sensors (group B) with a +# noticeable peak at ~8 Hz, and smaller peaks at 18 and 26 Hz. # %% @@ -208,8 +199,7 @@ # # Doing so, we see that the flow of information across the spectrum remains # dominant from parietal to occipital sensors (indicated by the positive-valued -# Granger scores). However, the pattern of connectivity is altered, such as -# around 10 and 12 Hz where peaks of net information flow are now present. +# Granger scores), with similar peaks around 10, 18, and 26 Hz. # %% @@ -289,8 +279,8 @@ # Plotting the TRGC results, reveals a very different picture compared to net # GC. For one, there is now a dominance of information flow ~6 Hz from # occipital to parietal sensors (indicated by the negative-valued Granger -# scores). Additionally, the peaks ~10 Hz are less dominant in the spectrum, -# with parietal to occipital information flow between 13-20 Hz being much more +# scores). Additionally, the peak ~10 Hz is less dominant in the spectrum, with +# parietal to occipital information flow between 13-20 Hz being much more # prominent. The stark difference between net GC and TRGC results indicates # that the net GC spectrum was contaminated by spurious connectivity resulting # from source mixing or correlated noise in the recordings. Altogether, the use @@ -366,8 +356,8 @@ # gets the singular values of the data s = np.linalg.svd(raw.get_data(), compute_uv=False) -# finds how many singular values are "close" to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria +# finds how many singular values are 'close' to the largest singular value +rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria ############################################################################### # Nonethless, even in situations where you specify an appropriate rank, it is @@ -387,7 +377,7 @@ try: spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None, - gc_n_lags=20) # A => B + gc_n_lags=20, verbose=False) # A => B print('Success!') except RuntimeError as error: print('\nCaught the following error:\n' + repr(error)) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py new file mode 100644 index 00000000..67f06a4e --- /dev/null +++ b/examples/handling_ragged_arrays.py @@ -0,0 +1,154 @@ +""" +========================================================= +Working with ragged indices for multivariate connectivity +========================================================= + +This example demonstrates how multivariate connectivity involving different +numbers of seeds and targets can be handled in MNE-Connectivity. +""" + +# Author: Thomas S. Binns +# License: BSD (3-clause) + +# %% + +import numpy as np + +from mne_connectivity import spectral_connectivity_epochs + +############################################################################### +# Background +# ---------- +# +# With multivariate connectivity, interactions between multiple signals can be +# considered together, and the number of signals designated as seeds and +# targets does not have to be equal within or across connections. Issues can +# arise from this when storing information associated with connectivity in +# arrays, as the number of entries within each dimension can vary within and +# across connections depending on the number of seeds and targets. Such arrays +# are 'ragged', and support for ragged arrays is limited in NumPy to the +# ``object`` datatype. Not only is working with ragged arrays is cumbersome, +# but saving arrays with ``dtype='object'`` is not supported by the h5netcdf +# engine used to save connectivity objects. The workaround used in +# MNE-Connectivity is to pad ragged arrays with some known values according to +# the largest number of entries in each dimension, such that there is an equal +# amount of information across and within connections for each dimension of the +# arrays. +# +# As an example, consider we have 5 channels and want to compute 2 connections: +# the first between channels in indices 0 and 1 with those in indices 2, 3, +# and 4; and the second between channels 0, 1, 2, and 3 with channel 4. The +# seed and target indices can be written as such:: +# +# seeds = [[0, 1 ], [0, 1, 2, 3]] +# targets = [[2, 3, 4], [4 ]] +# +# The ``indices`` parameter passed to +# :func:`~mne_connectivity.spectral_connectivity_epochs` and +# :func:`~mne_connectivity.spectral_connectivity_time` must be a tuple of +# array-likes, meaning +# that the indices can be passed as a tuple of: lists; tuples; or NumPy arrays. +# Examples of how ``indices`` can be formed are shown below:: +# +# # tuple of lists +# ragged_indices = ([[0, 1 ], [0, 1, 2, 3]], +# [[2, 3, 4], [4 ]]) +# +# # tuple of tuples +# ragged_indices = (((0, 1 ), (0, 1, 2, 3)), +# ((2, 3, 4), (4 ))) +# +# # tuple of arrays +# ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'), +# np.array([[2, 3, 4], [4 ]], dtype='object')) +# +# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be +# specified.** +# +# Just as for bivariate connectivity, the length of ``indices[0]`` and +# ``indices[1]`` is equal (i.e. the number of connections), however information +# about the multiple channel indices for each connection is stored in a nested +# array. Importantly, these indices are ragged, as the first connection will be +# computed between 2 seed and 3 target channels, and the second connection +# between 4 seed and 1 target channel. The connectivity functions will +# recognise the indices as being ragged, and pad them accordingly to make them +# easier to work with and compatible with the h5netcdf saving engine. The known +# value used to pad the arrays is ``-1``, an invalid channel index. The above +# indices would be padded to:: +# +# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]), +# np.array([[2, 3, 4, -1], [4, -1, -1, -1]])) +# +# These indices are what is stored in the connectivity object, and is also the +# format of indices returned from the helper functions +# :func:`~mne_connectivity.check_multivariate_indices` and +# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also +# possible to pass the padded indices to the connectivity functions directly. +# +# For the connectivity results themselves, the methods available in +# MNE-Connectivity combine information across the different channels into a +# single (time-)frequency-resolved connectivity spectrum, regardless of the +# number of seed and target channels, so ragged arrays are not a concern here. +# However, the maximised imaginary part of coherency (MIC) method also returns +# spatial patterns of connectivity, which show the contribution of each channel +# to the dimensionality-reduced connectivity estimate (explained in more detail +# in :doc:`mic_mim`). Because these patterns are returned for each channel, +# their shape can vary depending on the number of seeds and targets in each +# connection, making them ragged. To avoid this, the patterns are padded along +# the channel axis with the known and invalid entry ``np.nan``, in line with +# that applied to ``indices``. Extracting only the valid spatial patterns from +# the connectivity object is trivial, as shown below: + +# %% + +# create random data +data = np.random.randn(10, 5, 200) # epochs x channels x times +sfreq = 50 +ragged_indices = ([[0, 1], [0, 1, 2, 3]], # seeds + [[2, 3, 4], [4]]) # targets + +# compute connectivity +con = spectral_connectivity_epochs( + data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30, + verbose=False) +patterns = np.array(con.attrs['patterns']) +padded_indices = con.indices +n_freqs = con.get_data().shape[-1] +n_cons = len(ragged_indices[0]) +max_n_chans = max( + [len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])]) + +# show that the padded indices entries are all -1 +assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels +assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels +assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels +assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels + +# patterns have shape [seeds/targets x cons x max channels x freqs (x times)] +assert patterns.shape == (2, n_cons, max_n_chans, n_freqs) + +# show that the padded patterns entries are all np.nan +assert np.all(np.isnan(patterns[0, 0, 2:])) # 2 padded channels +assert np.all(np.isnan(patterns[1, 0, 3:])) # 1 padded channels +assert not np.any(np.isnan(patterns[0, 1])) # 0 padded channels +assert np.all(np.isnan(patterns[1, 1, 1:])) # 3 padded channels + +# extract patterns for first connection using the ragged indices +seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] +target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] + +# extract patterns for second connection using the padded indices (pad = -1) +seed_patterns_con2 = ( + patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)]) +target_patterns_con2 = ( + patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)]) + +# show that shapes of patterns are correct +assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) +assert target_patterns_con1.shape == (3, n_freqs) # channels (2, 3, 4) +assert seed_patterns_con2.shape == (4, n_freqs) # channels (0, 1, 2, 3) +assert target_patterns_con2.shape == (1, n_freqs) # channels (4) + +print('Assertions completed successfully!') + +# %% diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 179ea620..87111586 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -70,7 +70,7 @@ ############################################################################### # We will focus on connectivity between sensors over the left and right # hemispheres, with 75 sensors in the left hemisphere designated as seeds, and -# 75 sensors in the right hemisphere designated as targets. +# 76 sensors in the right hemisphere designated as targets. # %% @@ -81,13 +81,7 @@ targets = [idx for idx, ch_info in enumerate(epochs.info['chs']) if ch_info['loc'][0] > 0] -# XXX: Currently ragged indices are not supported, so we only consider a single -# list of indices with an equal number of seeds and targets -min_n_chs = min(len(seeds), len(targets)) -seeds = seeds[:min_n_chs] -targets = targets[:min_n_chs] - -multivar_indices = (np.array(seeds), np.array(targets)) +multivar_indices = (np.array([seeds]), np.array([targets])) seed_names = [epochs.info['ch_names'][idx] for idx in seeds] target_names = [epochs.info['ch_names'][idx] for idx in targets] @@ -171,12 +165,11 @@ # # Here, we average across the patterns in the 13-18 Hz range. Plotting the # patterns shows that the greatest connectivity between the left and right -# hemispheres occurs at the posteromedial regions, based on the regions with -# the largest absolute values. Using the signs of the values, we can infer the -# existence of a dipole source in the central regions of the left hemisphere -# which may account for the connectivity contributions seen for the left -# posteromedial and frontolateral areas (represented on the plot as a green -# line). +# hemispheres occurs at the left and right posterior and left central regions, +# based on the areas with the largest absolute values. Using the signs of the +# values, we can infer the existence of a dipole source between the central and +# posterior regions of the left hemisphere accounting for the connectivity +# contributions (represented on the plot as a green line). # %% @@ -185,9 +178,9 @@ fband_idx = [mic.freqs.index(freq) for freq in fband] # patterns have shape [seeds/targets x cons x channels x freqs (x times)] -patterns = np.array(mic.attrs["patterns"]) -seed_pattern = patterns[0] -target_pattern = patterns[1] +patterns = np.array(mic.attrs['patterns']) +seed_pattern = patterns[0, :, :len(seeds)] +target_pattern = patterns[1, :, :len(targets)] # average across frequencies seed_pattern = np.mean(seed_pattern[0, :, fband_idx[0]:fband_idx[1] + 1], axis=1) @@ -217,7 +210,7 @@ # plot the left hemisphere dipole example axes[0].plot( - [-0.1, -0.05], [-0.075, -0.03], color='lime', linewidth=2, + [-0.01, -0.07], [-0.07, -0.03], color='lime', linewidth=2, path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()]) plt.show() @@ -268,7 +261,7 @@ axis.set_ylabel('Absolute connectivity (A.U.)') fig.suptitle('Multivariate interaction measure') -n_channels = len(np.unique([*multivar_indices[0], *multivar_indices[1]])) +n_channels = len(seeds) + len(targets) normalised_mim = mim.get_data()[0] / n_channels print(f'Normalised MIM has a maximum value of {normalised_mim.max():.2f}') @@ -296,7 +289,7 @@ # %% -indices = (np.array([*seeds, *targets]), np.array([*seeds, *targets])) +indices = (np.array([[*seeds, *targets]]), np.array([[*seeds, *targets]])) gim = spectral_connectivity_epochs( epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None, verbose=False) @@ -307,7 +300,7 @@ axis.set_ylabel('Connectivity (A.U.)') fig.suptitle('Global interaction measure') -n_channels = len(np.unique([*indices[0], *indices[1]])) +n_channels = len(seeds) + len(targets) normalised_gim = gim.get_data()[0] / n_channels print(f'Normalised GIM has a maximum value of {normalised_gim.max():.2f}') @@ -369,9 +362,9 @@ # no. channels equal with and without projecting to rank subspace for patterns assert (patterns[0, 0].shape[0] == - np.array(mic_red.attrs["patterns"])[0, 0].shape[0]) + np.array(mic_red.attrs['patterns'])[0, 0].shape[0]) assert (patterns[1, 0].shape[0] == - np.array(mic_red.attrs["patterns"])[1, 0].shape[0]) + np.array(mic_red.attrs['patterns'])[1, 0].shape[0]) ############################################################################### @@ -392,8 +385,8 @@ # gets the singular values of the data s = np.linalg.svd(raw.get_data(), compute_uv=False) -# finds how many singular values are "close" to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria +# finds how many singular values are 'close' to the largest singular value +rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria ############################################################################### diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index 57aeff7f..c2f03a6c 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -17,4 +17,5 @@ from .io import read_connectivity from .spectral import spectral_connectivity_time, spectral_connectivity_epochs from .vector_ar import vector_auto_regression, select_order -from .utils import check_indices, degree, seed_target_indices +from .utils import (check_indices, check_multivariate_indices, degree, + seed_target_indices, seed_target_multivariate_indices) diff --git a/mne_connectivity/base.py b/mne_connectivity/base.py index 88951529..672448d1 100644 --- a/mne_connectivity/base.py +++ b/mne_connectivity/base.py @@ -483,7 +483,14 @@ def _prepare_xarray(self, data, names, indices, n_nodes, method, # set method, indices and n_nodes if isinstance(indices, tuple): - new_indices = (list(indices[0]), list(indices[1])) + if all([isinstance(inds, np.ndarray) for inds in indices]): + # leave multivariate indices as arrays for easier indexing + if all([inds.ndim > 1 for inds in indices]): + new_indices = (indices[0], indices[1]) + else: + new_indices = (list(indices[0]), list(indices[1])) + else: + new_indices = (list(indices[0]), list(indices[1])) indices = new_indices kwargs['method'] = method kwargs['indices'] = indices diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index eb766f06..c8499341 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -24,7 +24,7 @@ ProgressBar, _arange_div, _check_option, _time_mask, logger, warn, verbose) from ..base import (SpectralConnectivity, SpectroTemporalConnectivity) -from ..utils import fill_doc, check_indices +from ..utils import fill_doc, check_indices, check_multivariate_indices def _compute_freqs(n_times, sfreq, cwt_freqs, mode): @@ -92,40 +92,40 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, times = times_in[tmin_idx:tmax_idx] n_times = len(times) + if any(this_method in _multivariate_methods for this_method in method): + multivariate_con = True + else: + multivariate_con = False + if indices is None: - if any(this_method in _multivariate_methods for this_method in method): + if multivariate_con: if any(this_method in _gc_methods for this_method in method): raise ValueError( 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') else: logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int), - np.arange(n_signals, dtype=int)) + indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) else: logger.info('only using indices for lower-triangular matrix') # only compute r for lower-triangular region indices_use = np.tril_indices(n_signals, -1) else: - if any(this_method in _gc_methods for this_method in method): - if set(indices[0]).intersection(indices[1]): - raise ValueError( - 'seed and target indices must not intersect when computing' - 'Granger causality') - indices_use = check_indices(indices) + if multivariate_con: + indices_use = check_multivariate_indices(indices) # pad with -1 + if any(this_method in _gc_methods for this_method in method): + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries + raise ValueError( + 'seed and target indices must not intersect when ' + 'computing Granger causality') + else: + indices_use = check_indices(indices) # number of connectivities to compute - if any(this_method in _multivariate_methods for this_method in method): - if ( - len(np.unique(indices_use[0])) != len(indices_use[0]) or - len(np.unique(indices_use[1])) != len(indices_use[1]) - ): - raise ValueError( - 'seed and target indices cannot contain repeated channels for ' - 'multivariate connectivity') - n_cons = 1 # UNTIL RAGGED ARRAYS SUPPORTED - else: - n_cons = len(indices_use[0]) + n_cons = len(indices_use[0]) logger.info(' computing connectivity for %d connections' % n_cons) @@ -189,6 +189,46 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, n_signals, indices_use, warn_times) +def _check_rank_input(rank, data, indices): + """Check the rank argument is appropriate and compute rank if missing.""" + sv_tol = 1e-10 # tolerance for non-zero singular val (rel to largest) + if rank is None: + rank = np.zeros((2, len(indices[0])), dtype=int) + + if isinstance(data, BaseEpochs): + data_arr = data.get_data() + else: + data_arr = data + + for group_i in range(2): # seeds and targets + for con_i, con_idcs in enumerate(indices[group_i]): + con_idcs = con_idcs[con_idcs != -1] # -1 is padded value + s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) + rank[group_i][con_i] = np.min( + [np.count_nonzero(epoch >= epoch[0] * sv_tol) + for epoch in s]) + + logger.info('Estimated data ranks:') + con_i = 1 + for seed_rank, target_rank in zip(rank[0], rank[1]): + logger.info(' connection %i - seeds (%i); targets (%i)' + % (con_i, seed_rank, target_rank, )) + con_i += 1 + rank = tuple((np.array(rank[0]), np.array(rank[1]))) + + else: + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], rank[0], rank[1]): + if not (0 < seed_rank <= len(seed_idcs) and + 0 < target_rank <= len(target_idcs)): + raise ValueError( + 'ranks for seeds and targets must be > 0 and <= the ' + 'number of channels in the seeds and targets, ' + 'respectively, for each connection') + + return rank + + def _assemble_spectral_params(mode, n_times, mt_adaptive, mt_bandwidth, sfreq, mt_low_bias, cwt_n_cycles, cwt_freqs, freqs, freq_mask): @@ -422,22 +462,23 @@ def compute_con(self, indices, ranks, n_epochs=1): if self.name == 'MIC': self.patterns = np.full( - (2, self.n_cons, len(indices[0]), self.n_freqs, n_times), + (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), np.nan) con_i = 0 for seed_idcs, target_idcs, seed_rank, target_rank in zip( - [indices[0]], [indices[1]], ranks[0], ranks[1]): + indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - n_seeds = len(seed_idcs) + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] # Eqs. 32 & 33 C_bar, U_bar_aa, U_bar_bb = self._csd_svd( - C, n_seeds, seed_rank, target_rank) + C, seed_idcs, seed_rank, target_rank) # Eqs. 3 & 4 E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) @@ -452,10 +493,11 @@ def compute_con(self, indices, ranks, n_epochs=1): self.reshape_results() - def _csd_svd(self, csd, n_seeds, seed_rank, target_rank): + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): """Dimensionality reduction of CSD with SVD.""" n_times = csd.shape[0] - n_targets = csd.shape[2] - n_seeds + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds C_aa = csd[..., :n_seeds, :n_seeds] C_ab = csd[..., :n_seeds, n_seeds:] @@ -505,8 +547,9 @@ def _compute_e(self, csd, n_seeds): for block_i in ProgressBar( range(self.n_steps), mesg="frequency blocks"): freqs = self._get_block_indices(block_i, self.n_freqs) - parallel(parallel_compute_t( + T[:, freqs] = np.array(parallel(parallel_compute_t( C_r[:, f], T[:, f], n_seeds) for f in freqs) + ).transpose(1, 0, 2, 3) if not np.isreal(T).all() or not np.isfinite(T).all(): raise RuntimeError( @@ -526,6 +569,7 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, U_bar_bb, con_i): """Compute MIC and the associated spatial patterns.""" n_seeds = len(seed_idcs) + n_targets = len(target_idcs) times = np.arange(n_times) freqs = np.arange(self.n_freqs) @@ -564,12 +608,12 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] # Eq. 46 (seed spatial patterns) - self.patterns[0, con_i] = (np.matmul( + self.patterns[0, con_i, :n_seeds] = (np.matmul( np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T # Eq. 47 (target spatial patterns) - self.patterns[1, con_i] = (np.matmul( + self.patterns[1, con_i, :n_targets] = (np.matmul( np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T @@ -586,7 +630,7 @@ def _compute_mim(self, E, seed_idcs, target_idcs, con_i): E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T # Eq. 15 - if all(np.unique(seed_idcs) == np.unique(target_idcs)): + if np.all(np.unique(seed_idcs) == np.unique(target_idcs)): self.con_scores[con_i] *= 0.5 def reshape_results(self): @@ -598,7 +642,7 @@ def reshape_results(self): def _mic_mim_compute_t(C, T, n_seeds): - """Compute T in place for a single frequency (used for MIC and MIM).""" + """Compute T for a single frequency (used for MIC and MIM).""" for time_i in range(C.shape[0]): T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( C[time_i, :n_seeds, :n_seeds], -0.5 @@ -607,6 +651,8 @@ def _mic_mim_compute_t(C, T, n_seeds): C[time_i, n_seeds:, n_seeds:], -0.5 ) + return T + class _MICEst(_MultivariateCohEstBase): """Multivariate imaginary part of coherency (MIC) estimator.""" @@ -889,17 +935,16 @@ def compute_con(self, indices, ranks, n_epochs=1): con_i = 0 for seed_idcs, target_idcs, seed_rank, target_rank in zip( - [indices[0]], [indices[1]], ranks[0], ranks[1]): + indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] con_idcs = [*seed_idcs, *target_idcs] - C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - con_seeds = np.arange(len(seed_idcs)) - con_targets = np.arange(len(target_idcs)) + len(seed_idcs) + C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - C_bar = self._csd_svd( - C, con_seeds, con_targets, seed_rank, target_rank) + C_bar = self._csd_svd(C, seed_idcs, seed_rank, target_rank) n_signals = seed_rank + target_rank con_seeds = np.arange(seed_rank) con_targets = np.arange(target_rank) + seed_rank @@ -921,13 +966,13 @@ def compute_con(self, indices, ranks, n_epochs=1): self.reshape_results() - def _csd_svd(self, csd, seeds, targets, seed_rank, target_rank): + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): """Dimensionality reduction of CSD with SVD on the covariance.""" # sum over times and epochs to get cov. from CSD cov = csd.sum(axis=(0, 1)) - n_seeds = len(seeds) - n_targets = len(targets) + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds cov_aa = cov[:n_seeds, :n_seeds] cov_bb = cov[n_seeds:, n_seeds:] @@ -1202,7 +1247,7 @@ def _gc_compute_H(A, C, K, z_k, I_n, I_m): See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: 10.1103/PhysRevE.91.040101, Eq. 4. """ - from scipy import linalg # is this necessary??? + from scipy import linalg # XXX: is this necessary??? H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) for t in range(A.shape[0]): H[t] = I_n + np.matmul( @@ -1231,16 +1276,15 @@ class _GCTREst(_GCEstBase): def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, method, mode, window_fun, eigvals, wavelets, - freq_mask, mt_adaptive, idx_map, block_size, - psd, accumulate_psd, con_method_types, - con_methods, n_signals, n_signals_use, - n_times, gc_n_lags, accumulate_inplace=True): + freq_mask, mt_adaptive, idx_map, n_cons, + block_size, psd, accumulate_psd, + con_method_types, con_methods, n_signals, + n_signals_use, n_times, gc_n_lags, + accumulate_inplace=True): """Estimate connectivity for one epoch (see spectral_connectivity).""" if any(this_method in _multivariate_methods for this_method in method): - n_cons = 1 # UNTIL RAGGED ARRAYS SUPPORTED n_con_signals = n_signals_use ** 2 else: - n_cons = len(idx_map[0]) n_con_signals = n_cons if wavelets is not None: @@ -1513,10 +1557,11 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, 'mim', 'gc', 'gc_tr]``) cannot be called with the other methods. indices : tuple of array | None Two arrays with indices of connections for which to compute - connectivity. If a multivariate method is called, the indices are for a - single connection between all seeds and all targets. If None, all - connections are computed, unless a Granger causality method is called, - in which case an error is raised. + connectivity. If a multivariate method is called, each array for the + seeds and targets should contain a nested array of channel indices for + the individual connections. If None, connections between all channels + are computed, unless a Granger causality method is called, in which + case an error is raised. sfreq : float The sampling frequency. Required if data is not :class:`Epochs `. @@ -1582,14 +1627,13 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, con : array | list of array Computed connectivity measure(s). Either an instance of ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of each connectivity dataset is either - (n_signals ** 2, n_freqs) mode: 'multitaper' or 'fourier' - (n_signals ** 2, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is None, or - (n_con, n_freqs) mode: 'multitaper' or 'fourier' - (n_con, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is specified and "n_con = len(indices[0])". If a - multivariate method is called "n_con = 1" even if "indices" is None. + The shape of each connectivity dataset is either: + (n_cons, n_freqs) mode: 'multitaper' or 'fourier'; or + (n_cons, n_freqs, n_times) mode: 'cwt_morlet'. When "indices" is None + and a bivariate method is called, "n_cons = n_signals ** 2", or if a + multivariate method is called "n_cons = 1". When "indices" is + specified, "n_con = len(indices[0])" for bivariate and multivariate + methods. See Also -------- @@ -1635,13 +1679,19 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, are in the same order as defined indices. For multivariate methods, this is handled differently. If "indices" is - None, connectivity between all signals will attempt to be computed (this is - not possible if a Granger causality method is called). If "indices" is - specified, the seeds and targets are treated as a single connection. For - example, to compute the connectivity between signals 0, 1, 2 and 3, 4, 5, - one would use the same approach as above, however the signals would all be - considered for a single connection and the connectivity scores would have - the shape (1, n_freqs). + None, connectivity between all signals will be computed and a single + connectivity spectrum will be returned (this is not possible if a Granger + causality method is called). If "indices" is specified, seed and target + indices for each connection should be specified as nested array-likes. For + example, to compute the connectivity between signals (0, 1) -> (2, 3) and + (0, 1) -> (4, 5), indices should be specified as:: + + indices = (np.array([[0, 1], [0, 1]]), # seeds + np.array([[2, 3], [4, 5]])) # targets + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. **Supported Connectivity Measures** @@ -1834,11 +1884,15 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # check rank input and compute data ranks if necessary if multivariate_con: - rank = _check_rank_input(rank, data, sfreq, indices_use) + rank = _check_rank_input(rank, data, indices_use) else: rank = None gc_n_lags = None + # make sure padded indices are stored in the connectivity object + if multivariate_con and indices is not None: + indices = tuple(np.array(indices_use)) # create a copy + # get the window function, wavelets, etc for different modes (spectral_params, mt_adaptive, n_times_spectrum, n_tapers) = _assemble_spectral_params( @@ -1848,16 +1902,33 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) # unique signals for which we actually need to compute PSD etc. - sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) + if multivariate_con: + sig_idx = np.unique(np.concatenate(np.concatenate( + indices_use))) + sig_idx = sig_idx[sig_idx != -1] + remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} + remapping[-1] = -1 + remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + remapped_inds[0][con_i] = np.array([ + remapping[idx] for idx in seed]) + remapped_inds[1][con_i] = np.array([ + remapping[idx] for idx in target]) + con_i += 1 + remapped_sig = [remapping[idx] for idx in sig_idx] + else: + sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) n_signals_use = len(sig_idx) # map indices to unique indices - idx_map = [np.searchsorted(sig_idx, ind) for ind in indices_use] if multivariate_con: - indices_use = idx_map - idx_map = np.unique([*idx_map[0], *idx_map[1]]) - idx_map = [np.sort(np.repeat(idx_map, len(sig_idx))), - np.tile(idx_map, len(sig_idx))] + indices_use = remapped_inds # use remapped seeds & targets + idx_map = [np.sort(np.repeat(remapped_sig, len(sig_idx))), + np.tile(remapped_sig, len(sig_idx))] + else: + idx_map = [ + np.searchsorted(sig_idx, ind) for ind in indices_use] # allocate space to accumulate PSD if accumulate_psd: @@ -1894,7 +1965,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, call_params = dict( sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, - block_size=block_size, + n_cons=n_cons, block_size=block_size, psd=psd, accumulate_psd=accumulate_psd, mt_adaptive=mt_adaptive, con_method_types=con_method_types, @@ -1978,8 +2049,8 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, this_con = this_con_bands if this_patterns is not None: - patterns_shape = ((2, n_cons, len(indices[0]), n_bands) + - this_patterns.shape[4:]) + patterns_shape = list(this_patterns.shape) + patterns_shape[3] = n_bands this_patterns_bands = np.empty(patterns_shape, dtype=this_patterns.dtype) for band_idx in range(n_bands): @@ -2023,11 +2094,6 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # number of nodes in the original data n_nodes = n_signals - if multivariate_con: - # UNTIL RAGGED ARRAYS SUPPORTED - indices = tuple( - [[np.array(indices_use[0])], [np.array(indices_use[1])]]) - # create a list of connectivity containers conn_list = [] for _con, _patterns, _method in zip(con, patterns, method): @@ -2054,46 +2120,3 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, conn_list = conn_list[0] return conn_list - - -def _check_rank_input(rank, data, sfreq, indices): - """Check the rank argument is appropriate and compute rank if missing.""" - # UNTIL RAGGED ARRAYS SUPPORTED - indices = np.array([[indices[0]], [indices[1]]]) - - if rank is None: - - rank = np.zeros((2, len(indices[0])), dtype=int) - - if isinstance(data, BaseEpochs): - data_arr = data.get_data() - else: - data_arr = data - - for group_i in range(2): - for con_i, con_idcs in enumerate(indices[group_i]): - s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) - rank[group_i][con_i] = np.min( - [np.count_nonzero(epoch >= epoch[0] * 1e-10) - for epoch in s]) - - logger.info('Estimated data ranks:') - con_i = 1 - for seed_rank, target_rank in zip(rank[0], rank[1]): - logger.info(' connection %i - seeds (%i); targets (%i)' - % (con_i, seed_rank, target_rank, )) - con_i += 1 - - rank = tuple((np.array(rank[0]), np.array(rank[1]))) - - else: - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], rank[0], rank[1]): - if not (0 < seed_rank <= len(seed_idcs) and - 0 < target_rank <= len(target_idcs)): - raise ValueError( - 'ranks for seeds and targets must be > 0 and <= the ' - 'number of channels in the seeds and targets, ' - 'respectively, for each connection') - - return rank diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index fa8cf44d..a436aec8 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -423,7 +423,9 @@ def test_spectral_connectivity_epochs_multivariate(method): trans_bandwidth = 2.0 # Hz delay = 10 # samples (non-zero delay needed for ImCoh and GC to be >> 0) - indices = tuple([np.arange(n_seeds), np.arange(n_seeds) + n_seeds]) + indices = (np.arange(n_seeds)[np.newaxis, :], + np.arange(n_seeds)[np.newaxis, :] + n_seeds) + n_targets = n_seeds # 15-25 Hz connectivity fstart, fend = 15.0, 25.0 @@ -494,8 +496,17 @@ def test_spectral_connectivity_epochs_multivariate(method): if method in ['mic', 'mim']: con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=None, sfreq=sfreq) - assert (np.array(con.indices).tolist() == - [[[0, 1, 2, 3]], [[0, 1, 2, 3]]]) + assert con.indices is None + assert con.n_nodes == n_signals + if method == 'mic': + assert np.array(con.attrs['patterns']).shape[2] == n_signals + + # check ragged indices padded correctly + ragged_indices = (np.array([[0]]), np.array([[1, 2]])) + con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq) + assert np.all(np.array(con.indices) == + np.array([np.array([[0, -1]]), np.array([[1, 2]])])) # check shape of MIC patterns if method == 'mic': @@ -507,12 +518,12 @@ def test_spectral_connectivity_epochs_multivariate(method): if mode == 'cwt_morlet': patterns_shape = ( - (len(indices[0]), len(con.freqs), len(con.times)), - (len(indices[1]), len(con.freqs), len(con.times))) + (n_seeds, len(con.freqs), len(con.times)), + (n_targets, len(con.freqs), len(con.times))) else: patterns_shape = ( - (len(indices[0]), len(con.freqs)), - (len(indices[1]), len(con.freqs))) + (n_seeds, len(con.freqs)), + (n_targets, len(con.freqs))) assert np.shape(con.attrs["patterns"][0][0]) == patterns_shape[0] assert np.shape(con.attrs["patterns"][1][0]) == patterns_shape[1] @@ -532,10 +543,22 @@ def test_spectral_connectivity_epochs_multivariate(method): con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=rank) - assert (np.shape(con.attrs["patterns"][0][0])[0] == - len(indices[0])) - assert (np.shape(con.attrs["patterns"][1][0])[0] == - len(indices[1])) + assert (np.shape(con.attrs["patterns"][0][0])[0] == n_seeds) + assert (np.shape(con.attrs["patterns"][1][0])[0] == n_targets) + + # check patterns padded correctly + ragged_indices = (np.array([[0]]), np.array([[1, 2]])) + con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=ragged_indices, + sfreq=sfreq) + patterns = np.array(con.attrs["patterns"]) + patterns_shape = ( + (n_seeds, len(con.freqs)), (n_targets, len(con.freqs))) + assert patterns[0, 0].shape == patterns_shape[0] + assert patterns[1, 0].shape == patterns_shape[1] + assert not np.any(np.isnan(patterns[0, 0, 0])) + assert np.all(np.isnan(patterns[0, 0, 1])) + assert not np.any(np.isnan(patterns[1, 0])) def test_multivariate_spectral_connectivity_epochs_regression(): @@ -558,7 +581,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): data = pd.read_pickle( os.path.join(fpath, 'data', 'example_multivariate_data.pkl')) sfreq = 100 - indices = tuple([[0, 1], [2, 3]]) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) methods = ['mic', 'mim', 'gc', 'gc_tr'] con = spectral_connectivity_epochs( data, method=methods, indices=indices, mode='multitaper', sfreq=sfreq, @@ -587,13 +610,21 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): n_times = 256 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) cwt_freqs = np.arange(10, 25 + 1) - # check bad indices with repeated channels + # check bad indices without nested array caught + with pytest.raises(TypeError, + match='multivariate indices must contain array-likes'): + non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + spectral_connectivity_epochs( + data, method=method, mode=mode, indices=non_nested_indices, + sfreq=sfreq, gc_n_lags=10) + + # check bad indices with repeated channels caught with pytest.raises(ValueError, - match='seed and target indices cannot contain'): - repeated_indices = tuple([[0, 1, 1], [2, 2, 3]]) + match='multivariate indices cannot contain repeated'): + repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) spectral_connectivity_epochs( data, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq, gc_n_lags=10) @@ -644,7 +675,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): sfreq=sfreq, rank=(np.array([2]), np.array([2])), cwt_freqs=cwt_freqs) - # only check these once for speed + # only check these once (e.g. only with multitaper) for speed if method == 'gc' and mode == 'multitaper': # check bad n_lags caught frange = (5, 10) @@ -662,7 +693,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): cwt_freqs=cwt_freqs) # check intersecting indices caught - bad_indices = (np.array([0, 1]), np.array([0, 2])) + bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) with pytest.raises(ValueError, match='seed and target indices must not intersect'): spectral_connectivity_epochs(data, method=method, mode=mode, @@ -695,7 +726,7 @@ def test_multivar_spectral_connectivity_parallel(method): n_times = 256 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) spectral_connectivity_epochs( data, method=method, mode="multitaper", indices=indices, sfreq=sfreq, @@ -854,7 +885,7 @@ def test_spectral_connectivity_time_delayed(): trans_bandwidth = 2.0 # Hz delay = 5 # samples (non-zero delay needed for GC to be >> 0) - indices = tuple([np.arange(n_seeds), np.arange(n_seeds) + n_seeds]) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) # 20-30 Hz connectivity fstart, fend = 20.0, 30.0 @@ -1058,7 +1089,8 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): n_times = 500 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) + n_cons = len(indices[0]) freqs = np.arange(10, 25 + 1) con_shape = [1] @@ -1077,34 +1109,60 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): # check shape of MIC patterns are correct if method == 'mic': - patterns_shape = [len(indices[0])] - if faverage: - patterns_shape.append(1) - else: - patterns_shape.append(len(freqs)) - if not average: - patterns_shape = [n_epochs, *patterns_shape] - patterns_shape = [2, *patterns_shape] - assert np.array(con.attrs['patterns']).shape == tuple(patterns_shape) + for indices_type in ['full', 'ragged']: + if indices_type == 'full': + indices = (np.array([[0, 1]]), np.array([[2, 3]])) + else: + indices = (np.array([[0, 1]]), np.array([[2]])) + max_n_chans = 2 + patterns_shape = [n_cons, max_n_chans] + if faverage: + patterns_shape.append(1) + else: + patterns_shape.append(len(freqs)) + if not average: + patterns_shape = [n_epochs, *patterns_shape] + patterns_shape = [2, *patterns_shape] + con = spectral_connectivity_time( + data, freqs, indices=indices, method=method, sfreq=sfreq, + faverage=faverage, average=average, gc_n_lags=10) -@pytest.mark.parametrize( - 'method', ['mic', 'mim', 'gc', 'gc_tr']) -def test_multivar_spectral_connectivity_time_error_catch(method): + patterns = np.array(con.attrs['patterns']) + # 2 (x epochs) x cons x channels x freqs|fbands + assert (patterns.shape == tuple(patterns_shape)) + if indices_type == 'ragged': + assert not np.any(np.isnan(patterns[0, ..., :, :])) + assert not np.any(np.isnan(patterns[0, ..., 0, :])) + assert np.all(np.isnan(patterns[1, ..., 1, :])) # padded entry + assert np.all(np.array(con.indices) == np.array( + (np.array([[0, 1]]), np.array([[2, -1]])))) + + +@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize('mode', ['multitaper', 'cwt_morlet']) +def test_multivar_spectral_connectivity_time_error_catch(method, mode): """Test error catching for time-resolved multivar. connectivity methods.""" sfreq = 50. n_signals = 4 # Do not change! n_epochs = 8 n_times = 256 data = np.random.rand(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) freqs = np.arange(10, 25 + 1) - # check bad indices with repeated channels + # check bad indices without nested array caught + with pytest.raises(TypeError, + match='multivariate indices must contain array-likes'): + non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + spectral_connectivity_time(data, freqs, method=method, mode=mode, + indices=non_nested_indices, sfreq=sfreq) + + # check bad indices with repeated channels caught with pytest.raises(ValueError, - match='seed and target indices cannot contain'): - repeated_indices = tuple([[0, 1, 1], [2, 2, 3]]) - spectral_connectivity_time(data, freqs, method=method, + match='multivariate indices cannot contain repeated'): + repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq) # check mixed methods caught @@ -1112,7 +1170,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method): match='bivariate and multivariate connectivity'): mixed_methods = [method, 'coh'] spectral_connectivity_time(data, freqs, method=mixed_methods, - indices=indices, sfreq=sfreq) + mode=mode, indices=indices, sfreq=sfreq) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) @@ -1120,38 +1178,40 @@ def test_multivar_spectral_connectivity_time_error_catch(method): match='ranks for seeds and targets must be'): spectral_connectivity_time( data, freqs, method=method, indices=indices, sfreq=sfreq, - rank=too_low_rank) + mode=mode, rank=too_low_rank) too_high_rank = (np.array([3]), np.array([3])) with pytest.raises(ValueError, match='ranks for seeds and targets must be'): spectral_connectivity_time( data, freqs, method=method, indices=indices, sfreq=sfreq, - rank=too_high_rank) + mode=mode, rank=too_high_rank) # check all-to-all conn. computed for MIC/MIM when no indices given if method in ['mic', 'mim']: - con = spectral_connectivity_epochs( - data, freqs, method=method, indices=None, sfreq=sfreq) - assert (np.array(con.indices).tolist() == - [[[0, 1, 2, 3]], [[0, 1, 2, 3]]]) + con = spectral_connectivity_time( + data, freqs, method=method, indices=None, sfreq=sfreq, mode=mode) + assert con.indices is None + assert con.n_nodes == n_signals + if method == 'mic': + assert np.array(con.attrs['patterns']).shape[3] == n_signals if method in ['gc', 'gc_tr']: # check no indices caught with pytest.raises(ValueError, match='indices must be specified'): - spectral_connectivity_time(data, freqs, method=method, + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=None, sfreq=sfreq) # check intersecting indices caught - bad_indices = (np.array([0, 1]), np.array([0, 2])) + bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) with pytest.raises(ValueError, match='seed and target indices must not intersect'): - spectral_connectivity_time(data, freqs, method=method, + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=bad_indices, sfreq=sfreq) # check bad fmin/fmax caught with pytest.raises(ValueError, match='computing Granger causality on multiple'): - spectral_connectivity_time(data, freqs, method=method, + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=(5., 15.), fmax=(15., 30.)) @@ -1159,7 +1219,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method): def test_save(tmp_path): """Test saving results of spectral connectivity.""" rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 10, 2, 2000, 1000., 20. + n_epochs, n_chs, n_times, sfreq, f = 10, 3, 2000, 1000., 20. data = rng.randn(n_epochs, n_chs, n_times) sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) data[:, :, 500:1500] += sig @@ -1171,3 +1231,10 @@ def test_save(tmp_path): epochs, fmin=(4, 8, 13, 30), fmax=(8, 13, 30, 45), faverage=True) conn.save(tmp_path / 'foo.nc') + + # multivariate connectivity + # use ragged indices & MIC to test padding of indices and patterns + indices = (np.array([[0, 1]]), np.array([[2]])) + conn_mvc = spectral_connectivity_epochs( + epochs, method="mic", indices=indices, sfreq=sfreq, fmin=10, fmax=40) + conn_mvc.save(tmp_path / 'foo_mvc.nc') diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 7c1aabe6..6b5eb000 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -16,7 +16,7 @@ from .epochs import (_MICEst, _MIMEst, _GCEst, _GCTREst, _compute_freq_mask, _check_rank_input) from .smooth import _create_kernel, _smooth_spectra -from ..utils import check_indices, fill_doc +from ..utils import check_indices, check_multivariate_indices, fill_doc _multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr'] @@ -70,10 +70,11 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, :class:`EpochSpectralConnectivity`. indices : tuple of array_like | None Two arrays with indices of connections for which to compute - connectivity. If a multivariate method is called, the indices are for a - single connection between all seeds and all targets. If None, all - connections are computed, unless a Granger causality method is called, - in which case an error is raised. + connectivity. If a multivariate method is called, each array for the + seeds and targets should contain a nested array of channel indices for + the individual connections. If None, connections between all channels + are computed, unless a Granger causality method is called, in which + case an error is raised. sfreq : float The sampling frequency. Required if data is not :class:`Epochs `. @@ -144,11 +145,11 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, :class:`EpochSpectralConnectivity`, :class:`SpectralConnectivity` or a list of instances corresponding to connectivity measures if several connectivity measures are specified. - The shape of each connectivity dataset is - (n_epochs, n_signals, n_signals, n_freqs) when ``indices`` is `None`, - (n_epochs, n_nodes, n_nodes, n_freqs) when ``indices`` is specified - and ``n_nodes = len(indices[0])``, or (n_epochs, 1, 1, n_freqs) when a - multi-variate method is called regardless of "indices". + The shape of each connectivity dataset is (n_epochs, n_cons, n_freqs). + When "indices" is None and a bivariate method is called, + "n_cons = n_signals ** 2", or if a multivariate method is called + "n_cons = 1". When "indices" is specified, "n_con = len(indices[0])" + for bivariate and multivariate methods. See Also -------- @@ -202,13 +203,19 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, scores are in the same order as defined indices. For multivariate methods, this is handled differently. If "indices" is - None, connectivity between all signals will attempt to be computed (this is - not possible if a Granger causality method is called). If "indices" is - specified, the seeds and targets are treated as a single connection. For - example, to compute the connectivity between signals 0, 1, 2 and 3, 4, 5, - one would use the same approach as above, however the signals would all be - considered for a single connection and the connectivity scores would have - the shape (1, n_freqs). + None, connectivity between all signals will be computed and a single + connectivity spectrum will be returned (this is not possible if a Granger + causality method is called). If "indices" is specified, seed and target + indices for each connection should be specified as nested array-likes. For + example, to compute the connectivity between signals (0, 1) -> (2, 3) and + (0, 1) -> (4, 5), indices should be specified as:: + + indices = (np.array([[0, 1], [0, 1]]), # seeds + np.array([[2, 3], [4, 5]])) # targets + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. **Supported Connectivity Measures** @@ -398,36 +405,51 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int), - np.arange(n_signals, dtype=int)) + indices_use = (np.array([np.arange(n_signals, dtype=np.int32)]), + np.array([np.arange(n_signals, dtype=np.int32)])) else: logger.info('only using indices for lower-triangular matrix') indices_use = np.tril_indices(n_signals, k=-1) else: if multivariate_con: - if ( - len(np.unique(indices[0])) != len(indices[0]) or - len(np.unique(indices[1])) != len(indices[1]) - ): - raise ValueError( - 'seed and target indices cannot contain repeated ' - 'channels for multivariate connectivity') + indices_use = check_multivariate_indices(indices) # pad with -1 if any(this_method in _gc_methods for this_method in method): - if set(indices[0]).intersection(indices[1]): - raise ValueError( - 'seed and target indices must not intersect when ' - 'computing Granger causality') - indices_use = check_indices(indices) - source_idx = indices_use[0] - target_idx = indices_use[1] - n_pairs = len(source_idx) if not multivariate_con else 1 + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries + raise ValueError( + 'seed and target indices must not intersect when ' + 'computing Granger causality') + # make sure padded indices are stored in the connectivity object + indices = tuple(np.array(indices_use)) # create a copy + else: + indices_use = check_indices(indices) + # create copies of indices_use for independent manipulation + source_idx = np.array(indices_use[0]) + target_idx = np.array(indices_use[1]) + n_cons = len(source_idx) # unique signals for which we actually need to compute the CSD of - signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + if multivariate_con: + signals_use = np.unique(np.concatenate(np.concatenate(indices_use))) + signals_use = signals_use[signals_use != -1] + remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(signals_use)} + remapping[-1] = -1 + # multivariate functions expect seed/target remapping + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + source_idx[con_i] = np.array([remapping[idx] for idx in seed]) + target_idx[con_i] = np.array([remapping[idx] for idx in target]) + con_i += 1 + max_n_channels = len(indices_use[0][0]) + else: + # no indices remapping required for bivariate functions + signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + max_n_channels = len(indices_use[0]) # check rank input and compute data ranks if necessary if multivariate_con: - rank = _check_rank_input(rank, data, sfreq, indices_use) + rank = _check_rank_input(rank, data, indices_use) else: rank = None gc_n_lags = None @@ -479,9 +501,10 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, conn = dict() conn_patterns = dict() for m in method: - conn[m] = np.zeros((n_epochs, n_pairs, n_freqs)) - conn_patterns[m] = np.full((n_epochs, 2, len(source_idx), n_freqs), - np.nan) + conn[m] = np.zeros((n_epochs, n_cons, n_freqs)) + # patterns shape of [epochs x seeds/targets x cons x channels x freqs] + conn_patterns[m] = np.full( + (n_epochs, 2, n_cons, max_n_channels, n_freqs), np.nan) logger.info('Connectivity computation...') # parameters to pass to the connectivity function @@ -505,8 +528,8 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, if np.isnan(conn_patterns[m]).all(): conn_patterns[m] = None else: - # epochs x 2 x n_channels x n_freqs - conn_patterns[m] = conn_patterns[m].transpose((1, 0, 2, 3)) + # transpose to [seeds/targets x epochs x cons x channels x freqs] + conn_patterns[m] = conn_patterns[m].transpose((1, 0, 2, 3, 4)) if indices is None and not multivariate_con: conn_flat = conn @@ -520,11 +543,6 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, conn_flat[m].shape[2:]) conn[m] = this_conn - if multivariate_con: - # UNTIL RAGGED ARRAYS SUPPORTED - indices = tuple( - [[np.array(indices_use[0])], [np.array(indices_use[1])]]) - # create the connectivity containers out = [] for m in method: @@ -569,9 +587,9 @@ def _spectral_connectivity(data, method, kernel, foi_idx, Smoothing kernel. foi_idx : array_like, shape (n_foi, 2) Upper and lower bound indices of frequency bands. - source_idx : array_like, shape (n_pairs,) + source_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``target_idx``. - target_idx : array_like, shape (n_pairs,) + target_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``source_idx``. signals_use : list of int The unique signals on which connectivity is to be computed. @@ -608,8 +626,8 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ------- scores : dict Dictionary containing the connectivity estimates corresponding to the - metrics in ``method``. Each element is an array of shape (n_pairs, - n_freqs) or (n_pairs, n_fbands) if ``faverage`` is `True`. + metrics in ``method``. Each element is an array of shape (n_cons, + n_freqs) or (n_cons, n_fbands) if ``faverage`` is `True`. patterns : dict Dictionary containing the connectivity patterns (for reconstructing the @@ -619,7 +637,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx, or (2, n_channels, 1) if ``faverage`` is `True`, where 2 corresponds to the seed and target signals (respectively). """ - n_pairs = len(source_idx) + n_cons = len(source_idx) data = np.expand_dims(data, axis=0) if mode == 'cwt_morlet': out = tfr_array_morlet( @@ -665,12 +683,12 @@ def _spectral_connectivity(data, method, kernel, foi_idx, scores = {} patterns = {} conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx, - signals_use, gc_n_lags, rank, n_jobs, verbose, - n_pairs, faverage, weights, multivariate_con) + signals_use, gc_n_lags, rank, n_jobs, verbose, n_cons, + faverage, weights, multivariate_con) for i, m in enumerate(method): if multivariate_con: scores[m] = conn[0][i] - patterns[m] = conn[1][i][:, 0] if conn[1][i] is not None else None + patterns[m] = conn[1][i] if conn[1][i] is not None else None else: scores[m] = [out[i] for out in conn] patterns[m] = None @@ -699,16 +717,16 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, Smoothing kernel. foi_idx : array_like, shape (n_foi, 2) Upper and lower bound indices of frequency bands. - source_idx : array_like, shape (n_pairs,) + source_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``target_idx``. - target_idx : array_like, shape (n_pairs,) + target_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``source_idx``. signals_use : list of int The unique signals on which connectivity is to be computed. gc_n_lags : int Number of lags to use for the vector autoregressive model when computing Granger causality. - rank : tuple of array + rank : tuple of array of int Ranks to project the seed and target data to. n_jobs : int Number of parallel jobs. @@ -825,18 +843,23 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, return out -def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, - foi_idx, faverage, weights, gc_n_lags, rank, n_jobs): +def _multivariate_con(w, seeds, targets, signals_use, method, kernel, foi_idx, + faverage, weights, gc_n_lags, rank, n_jobs): """Compute spectral connectivity metrics between multiple signals. Parameters ---------- w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) Time-frequency data. - x : int - Channel index. - y : int - Channel index. + seeds : array, shape of (n_cons, n_channels) + Seed channel indices. ``n_channels`` is the largest number of channels + across all connections, with missing entries padded with ``-1``. + targets : array, shape of (n_cons, n_channels) + Target channel indices. ``n_channels`` is the largest number of + channels across all connections, with missing entries padded with + ``-1``. + signals_use : list of int + The unique signals on which connectivity is to be computed. method : str Connectivity method. kernel : array_like, shape (n_sm_fres, n_sm_times) @@ -847,6 +870,13 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, Average over frequency bands. weights : array_like, shape (n_tapers, n_freqs, n_times) | None Multitaper weights. + gc_n_lags : int + Number of lags to use for the vector autoregressive model when + computing Granger causality. + rank : tuple of array, shape of (2, n_cons) + Ranks to project the seed and target data to. + n_jobs : int + Number of jobs to run in parallel. Returns ------- @@ -859,8 +889,10 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, List of connectivity patterns between seed and target signals for each connectivity method. Each element is an array of length 2 corresponding to the seed and target patterns, respectively, each with shape - (n_channels, n_freqs,) or (n_channels, n_fbands) depending on - ``faverage``. + (n_channels, n_freqs) or (n_channels, n_fbands) + depending on ``faverage``. ``n_channels`` is the largest number of + channels across all connections, with missing entries padded with + ``np.nan``. """ csd = [] for x in signals_use: @@ -880,8 +912,7 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, 'gc_tr': _GCTREst} conn = [] for m in method: - # N_CONS = 1 UNTIL RAGGED ARRAYS SUPPORTED - call_params = {'n_signals': len(signals_use), 'n_cons': 1, + call_params = {'n_signals': len(signals_use), 'n_cons': len(seeds), 'n_freqs': csd.shape[1], 'n_times': 0, 'n_jobs': n_jobs} if m in _gc_methods: @@ -895,7 +926,7 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, scores = [] patterns = [] for con_est in conn: - con_est.compute_con(np.array([source_idx, target_idx]), rank) + con_est.compute_con((seeds, targets), rank) scores.append(con_est.con_scores[..., np.newaxis]) patterns.append(con_est.patterns) if patterns[-1] is not None: diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index c012481c..0549ee43 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -3,11 +3,15 @@ from numpy.testing import assert_array_equal from mne_connectivity import Connectivity -from mne_connectivity.utils import degree, seed_target_indices +from mne_connectivity.utils import (degree, check_indices, + check_multivariate_indices, + seed_target_indices, + seed_target_multivariate_indices) -def test_indices(): - """Test connectivity indexing methods.""" +def test_seed_target_indices(): + """Test indices generation functions.""" + # bivariate indices n_seeds_test = [1, 3, 4] n_targets_test = [2, 3, 200] rng = np.random.RandomState(42) @@ -25,6 +29,69 @@ def test_indices(): for target in targets: assert np.sum(indices[1] == target) == n_seeds + # multivariate indices + # non-ragged indices + seeds = [[0, 1]] + targets = [[2, 3], [3, 4]] + indices = seed_target_multivariate_indices(seeds, targets) + assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), + np.array([[2, 3], [3, 4]]))) + # ragged indices + seeds = [[0, 1]] + targets = [[2, 3, 4], [4]] + indices = seed_target_multivariate_indices(seeds, targets) + assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), + np.array([[2, 3, 4], [4, -1, -1]]))) + # test error catching + # non-array-like seeds/targets + with pytest.raises(TypeError, + match='`seeds` and `targets` must be array-like'): + seed_target_multivariate_indices(0, 1) + # non-nested seeds/targets + with pytest.raises(TypeError, + match='`seeds` and `targets` must contain nested'): + seed_target_multivariate_indices([0], [1]) + + +def test_check_indices(): + """Test indices checking functions.""" + # bivariate indices + # test error catching + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_tuple_indices = [[0], [1]] + check_indices(non_tuple_indices) + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_len2_indices = ([0], [1], [2]) + check_indices(non_len2_indices) + with pytest.raises(ValueError, match='Index arrays indices'): + non_equal_len_indices = ([0], [1, 2]) + check_indices(non_equal_len_indices) + + # multivariate indices + # non-ragged indices + seeds = [[0, 1], [0, 1]] + targets = [[2, 3], [3, 4]] + indices = check_multivariate_indices((seeds, targets)) + assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), + np.array([[2, 3], [3, 4]]))) + # ragged indices + seeds = [[0, 1], [0, 1]] + targets = [[2, 3, 4], [4]] + indices = check_multivariate_indices((seeds, targets)) + assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), + np.array([[2, 3, 4], [4, -1, -1]]))) + # test error catching + with pytest.raises(TypeError, + match='multivariate indices must contain array-likes'): + non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + check_multivariate_indices(non_nested_indices) + with pytest.raises(ValueError, + match='multivariate indices cannot contain repeated'): + repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) + check_multivariate_indices(repeated_indices) + def test_degree(): """Test degree function.""" diff --git a/mne_connectivity/utils/__init__.py b/mne_connectivity/utils/__init__.py index e82f054b..0df454a4 100644 --- a/mne_connectivity/utils/__init__.py +++ b/mne_connectivity/utils/__init__.py @@ -1,3 +1,4 @@ from .docs import fill_doc -from .utils import (check_indices, degree, seed_target_indices, +from .utils import (check_indices, check_multivariate_indices, degree, + seed_target_indices, seed_target_multivariate_indices, parallel_loop, _prepare_xarray_mne_data_structures) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 5ae94acb..b8216654 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -71,6 +71,47 @@ def check_indices(indices): return indices +def check_multivariate_indices(indices): + """Check indices parameter for multivariate connectivity and pad it. + + Parameters + ---------- + indices : tuple of array-like of array-like of int + Tuple of length 2 containing index pairs. + + Returns + ------- + indices : tuple of array of array of int + The indices padded with the invalid channel index ``-1``. + """ + indices = check_indices(indices) + n_cons = len(indices[0]) + + n_chans = [] + for inds in ([*indices[0], *indices[1]]): + if not isinstance(inds, (np.ndarray, list, tuple)): + raise TypeError( + 'multivariate indices must contain array-likes of channel ' + 'indices for each seed and target') + if len(inds) != len(np.unique(inds)): + raise ValueError( + 'multivariate indices cannot contain repeated channels within ' + 'a seed or target') + n_chans.append(len(inds)) + max_n_chans = np.max(n_chans) + + # pad indices to avoid ragged arrays + padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32), + np.full((n_cons, max_n_chans), -1, dtype=np.int32)) + con_i = 0 + for seed, target in zip(indices[0], indices[1]): + padded_indices[0][con_i, :len(seed)] = seed + padded_indices[1][con_i, :len(target)] = target + con_i += 1 + + return padded_indices + + def seed_target_indices(seeds, targets): """Generate indices parameter for seed based connectivity analysis. @@ -99,6 +140,60 @@ def seed_target_indices(seeds, targets): return indices +def seed_target_multivariate_indices(seeds, targets): + """Generate indices parameter for multivariate seed-based connectivity. + + Parameters + ---------- + seeds : array-like of array-like of int + Seed indices. + + targets : array-like of array-like of int + Target indices. + + Returns + ------- + indices : tuple of array of array of int + The indices padded with the invalid channel index ``-1``. + """ + array_like = (np.ndarray, list, tuple) + + if ( + not isinstance(seeds, array_like) or + not isinstance(targets, array_like) + ): + raise TypeError('`seeds` and `targets` must be array-like') + + n_chans = [] + for inds in [*seeds, *targets]: + if not isinstance(inds, array_like): + raise TypeError( + '`seeds` and `targets` must contain nested array-likes') + n_chans.append(len(inds)) + max_n_chans = max(n_chans) + n_cons = len(seeds) * len(targets) + + # pad indices to avoid ragged arrays + padded_seeds = np.full((len(seeds), max_n_chans), -1, dtype=np.int32) + padded_targets = np.full((len(targets), max_n_chans), -1, dtype=np.int32) + for con_i, seed in enumerate(seeds): + padded_seeds[con_i, :len(seed)] = seed + for con_i, target in enumerate(targets): + padded_targets[con_i, :len(target)] = target + + # create final indices + indices = (np.zeros((n_cons, max_n_chans), dtype=np.int32), + np.zeros((n_cons, max_n_chans), dtype=np.int32)) + con_i = 0 + for seed in padded_seeds: + for target in padded_targets: + indices[0][con_i] = seed + indices[1][con_i] = target + con_i += 1 + + return indices + + def degree(connectivity, threshold_prop=0.2): """Compute the undirected degree of a connectivity matrix. From 59abbcbe178d4587b57c1d1634e1e15e024d4967 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 24 Jul 2023 12:20:31 +0200 Subject: [PATCH 02/40] added author --- mne_connectivity/utils/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index b8216654..803556a8 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -1,4 +1,5 @@ # Authors: Martin Luessi +# Thomas S. Binns # # License: BSD (3-clause) import numpy as np From 8b8b0c2b147c9b9d91244d07bd0dbfbbb84eafc4 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 24 Jul 2023 12:36:42 +0200 Subject: [PATCH 03/40] bug fix ragged indices comparison --- mne_connectivity/spectral/epochs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index c8499341..5730f38a 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -578,7 +578,7 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, np.matmul(E, E.transpose(0, 1, 3, 2))) w_targets, V_targets = np.linalg.eigh( np.matmul(E.transpose(0, 1, 3, 2), E)) - if np.all(seed_idcs == target_idcs): + if np.all(np.unique(seed_idcs) == np.unique(target_idcs)): # strange edge-case where the eigenvectors returned should be a set # of identity matrices with one rotated by 90 degrees, but are # instead identical (i.e. are not rotated versions of one another). From d3ed2e964bd9412c03bd3851a764f3bb7804b7c6 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 24 Jul 2023 12:56:28 +0200 Subject: [PATCH 04/40] bug fix ragged indices comparison --- mne_connectivity/spectral/epochs.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 5730f38a..8efbb18e 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -578,12 +578,15 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, np.matmul(E, E.transpose(0, 1, 3, 2))) w_targets, V_targets = np.linalg.eigh( np.matmul(E.transpose(0, 1, 3, 2), E)) - if np.all(np.unique(seed_idcs) == np.unique(target_idcs)): + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.unique(seed_idcs) == np.unique(target_idcs)) + ): # strange edge-case where the eigenvectors returned should be a set # of identity matrices with one rotated by 90 degrees, but are # instead identical (i.e. are not rotated versions of one another). # This leads to the case where the spatial filters are incorrectly - # applied, resulting in connectivity estimates of e.g. ~0 when they + # applied, resulting in connectivity estimates of ~0 when they # should be perfectly correlated ~1. Accordingly, we manually # create a set of rotated identity matrices to use as the filters. create_filter = False @@ -630,7 +633,10 @@ def _compute_mim(self, E, seed_idcs, target_idcs, con_i): E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T # Eq. 15 - if np.all(np.unique(seed_idcs) == np.unique(target_idcs)): + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.unique(seed_idcs) == np.unique(target_idcs)) + ): self.con_scores[con_i] *= 0.5 def reshape_results(self): From 1a1d0d189ec576d5750bb5b4c68434d8e9f6d346 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 24 Jul 2023 12:57:24 +0200 Subject: [PATCH 05/40] bug fix ragged indices comparison --- mne_connectivity/spectral/epochs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 8efbb18e..fc617bd7 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -580,7 +580,7 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, np.matmul(E.transpose(0, 1, 3, 2), E)) if ( len(seed_idcs) == len(target_idcs) and - np.all(np.unique(seed_idcs) == np.unique(target_idcs)) + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) ): # strange edge-case where the eigenvectors returned should be a set # of identity matrices with one rotated by 90 degrees, but are @@ -635,7 +635,7 @@ def _compute_mim(self, E, seed_idcs, target_idcs, con_i): # Eq. 15 if ( len(seed_idcs) == len(target_idcs) and - np.all(np.unique(seed_idcs) == np.unique(target_idcs)) + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) ): self.con_scores[con_i] *= 0.5 From f053b5257cf096dd0063d0dea5df265e297e9b9e Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 25 Jul 2023 11:10:26 +0200 Subject: [PATCH 06/40] added extra multivariate indices unit test --- .../spectral/tests/test_spectral.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index a436aec8..262ddbb3 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -736,6 +736,53 @@ def test_multivar_spectral_connectivity_parallel(method): indices=indices, sfreq=sfreq, gc_n_lags=10, n_jobs=2) +def test_multivar_spectral_connectivity_flipped_indices(): + """Test multivar. indices structure maintained by connectivity methods.""" + sfreq = 50. + n_signals = 4 + n_epochs = 8 + n_times = 256 + rng = np.random.RandomState(0) + data = rng.randn(n_epochs, n_signals, n_times) + freqs = np.arange(10, 20) + + # if we're not careful, when finding the channels we need to compute the + # CSD for, we might accidentally reorder the connectivity indices + indices = (np.array([[0, 1]]), + np.array([[2, 3]])) + flipped_indices = (np.array([[2, 3]]), + np.array([[0, 1]])) + concat_indices = (np.array([[0, 1], [2, 3]]), + np.array([[2, 3], [0, 1]])) + + # we test on GC since this is a directed connectivity measure + method = 'gc' + + con_st = spectral_connectivity_epochs( # seed -> target + data, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10) + con_ts = spectral_connectivity_epochs( # target -> seed + data, method=method, indices=flipped_indices, sfreq=sfreq, + gc_n_lags=10) + con_st_ts = spectral_connectivity_epochs( # seed -> target; target -> seed + data, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10) + assert not np.all(con_st.get_data() == con_ts.get_data()) + assert np.all(con_st.get_data()[0] == con_st_ts.get_data()[0]) + assert np.all(con_ts.get_data()[0] == con_st_ts.get_data()[1]) + + con_st = spectral_connectivity_time( # seed -> target + data, freqs, method=method, indices=indices, sfreq=sfreq, + gc_n_lags=10) + con_ts = spectral_connectivity_time( # target -> seed + data, freqs, method=method, indices=flipped_indices, sfreq=sfreq, + gc_n_lags=10) + con_st_ts = spectral_connectivity_time( # seed -> target; target -> seed + data, freqs, method=method, indices=concat_indices, sfreq=sfreq, + gc_n_lags=10) + assert not np.all(con_st.get_data() == con_ts.get_data()) + assert np.all(con_st.get_data()[:, 0] == con_st_ts.get_data()[:, 0]) + assert np.all(con_ts.get_data()[:, 0] == con_st_ts.get_data()[:, 1]) + + @ pytest.mark.parametrize('kind', ('epochs', 'ndarray', 'stc', 'combo')) def test_epochs_tmin_tmax(kind): """Test spectral.spectral_connectivity_epochs with epochs and arrays.""" From 90c5a73f9875ed9b0ddce87a0d0de65fe93e525a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 14 Aug 2023 16:44:41 +0200 Subject: [PATCH 07/40] updated utils tests and docs --- doc/conf.py | 3 +- mne_connectivity/tests/test_utils.py | 4 + mne_connectivity/utils/utils.py | 120 +++++++++++++++++++++++---- 3 files changed, 112 insertions(+), 15 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index dc385650..724dfbc4 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -91,7 +91,8 @@ 'n_node_names', 'n_tapers', 'n_signals', 'n_step', 'n_freqs', 'epochs', 'freqs', 'times', 'arrays', 'lists', 'func', 'n_nodes', 'n_estimated_nodes', 'n_samples', 'n_channels', 'Renderer', - 'n_ytimes', 'n_ychannels', 'n_events' + 'n_ytimes', 'n_ychannels', 'n_events', 'n_cons', 'max_n_chans', + 'n_unique_seeds', 'n_unique_targets', 'variable' } numpydoc_xref_aliases = { # Python diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index 0549ee43..0d678da2 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -51,6 +51,10 @@ def test_seed_target_indices(): with pytest.raises(TypeError, match='`seeds` and `targets` must contain nested'): seed_target_multivariate_indices([0], [1]) + # repeated seeds/targets + with pytest.raises(ValueError, + match='`seeds` and `targets` cannot contain repeated'): + seed_target_multivariate_indices([[0, 1, 1]], [[2, 2, 3]]) def test_check_indices(): diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 803556a8..515f7f53 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -50,17 +50,24 @@ def par(x): def check_indices(indices): - """Check indices parameter. + """Check indices parameter for bivariate connetivity. Parameters ---------- - indices : tuple of array - Tuple of length 2 containing index pairs. + indices : tuple of array of int, shape (2, n_cons) + Tuple containing index pairs. Returns ------- - indices : tuple of array + indices : tuple of array of int, shape (2, n_cons) The indices. + + Notes + ----- + Indices for bivariate connectivity should be a tuple of length 2, + containing the channel indices for the seed and target channel pairs, + respectively. Seed and target indices should be equal-length array-likes of + integers representing the indices of the individual channels in the data. """ if not isinstance(indices, tuple) or len(indices) != 2: raise ValueError('indices must be a tuple of length 2') @@ -77,13 +84,46 @@ def check_multivariate_indices(indices): Parameters ---------- - indices : tuple of array-like of array-like of int - Tuple of length 2 containing index pairs. + indices : tuple of array of array of int, shape (2, n_cons, variable) + Tuple containing index sets. Returns ------- - indices : tuple of array of array of int + indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans) The indices padded with the invalid channel index ``-1``. + + Notes + ----- + Indices for multivariate connectivity should be a tuple of length 2 + containing the channel indices for the seed and target channel sets, + respectively. Seed and target indices should be equal-length array-likes + representing the indices of the channel sets in the data for each + connection. The indices for each connection should be an array-like of + integers representing the individual channels in the data. The length of + indices for each connection do not need to be equal. All indices within a + connection must be unique. + + If the seed and target indices are given as lists or tuples, they will be + converted to numpy arrays. In case the number of channels differs across + connections or between the seeds and targets for a given connection (i.e. + ragged indices), the returned array will be padded with the invalid channel + index ``-1`` according to the maximum number of channels in the seed or + target of any one connection. E.g. the ragged indices of shape `(2, n_cons, + variable)`:: + + indices = ([[0, 1], [0, 1 ]], # seeds + [[2, 3], [4, 5, 6]]) # targets + + would be returned as:: + + indices = (np.array([[0, 1, -1], [0, 1, -1]]), # seeds + np.array([[2, 3, -1], [4, 5, -1]])) # targets + + where the indices have been padded with ``-1`` to have shape `(2, n_cons, + max_n_chans)`, where `max_n_chans = 3`. More information on working with + multivariate indices and handling connections where the number of seeds and + targets are not equal can be found in the + :doc:`../auto_examples/handling_ragged_arrays` example. """ indices = check_indices(indices) n_cons = len(indices[0]) @@ -114,19 +154,37 @@ def check_multivariate_indices(indices): def seed_target_indices(seeds, targets): - """Generate indices parameter for seed based connectivity analysis. + """Generate indices parameter for bivariate seed-based connectivity. Parameters ---------- - seeds : array of int | int + seeds : array of int | int, shape (n_unique_seeds) Seed indices. - targets : array of int | int + targets : array of int | int, shape (n_unique_targets) Indices of signals for which to compute connectivity. Returns ------- - indices : tuple of array + indices : tuple of array of int, shape (2, n_cons) The indices parameter used for connectivity computation. + + Notes + ----- + `seeds` and `targets` should be array-likes or integers representing the + indices of the channel pairs in the data for each connection. `seeds` and + `targets` will be expanded such that connectivity will be computed between + each seed and each target. E.g. the seeds and targets:: + + seeds = [0, 1] + targets = [2, 3, 4] + + would be returned as:: + + indices = (np.array([0, 0, 0, 1, 1, 1]), # seeds + np.array([2, 3, 4, 2, 3, 4])) # targets + + where the indices have been expanded to have shape `(2, n_cons)`, where + `n_cons = n_unique_seeds * n_unique_targets`. """ # make them arrays seeds = np.asarray((seeds,)).ravel() @@ -146,16 +204,47 @@ def seed_target_multivariate_indices(seeds, targets): Parameters ---------- - seeds : array-like of array-like of int + seeds : array of array of int, shape (n_unique_seeds, variable) Seed indices. - targets : array-like of array-like of int + targets : array of array of int, shape (n_unique_targets, variable) Target indices. Returns ------- - indices : tuple of array of array of int + indices : tuple of array of array of int, shape (2, n_cons, max_n_chans) The indices padded with the invalid channel index ``-1``. + + Notes + ----- + `seeds` and `targets` should be array-likes representing the indices of the + channel sets in the data for each connection. The indices for each + connection should be an array-like of integers representing the individual + channels in the data. The length of indices for each connection do not need + to be equal. Furthermore, all indices within a connection must be unique. + + `seeds` and `targets` will be expanded such that connectivity will be + computed between each set of seeds and targets. In case the number of + channels differs across connections or between the seeds and targets for a + given connection (i.e. ragged indices), the returned array will be padded + with the invalid channel index ``-1`` according to the maximum number of + channels in the seed or target of any one connection. E.g. `seeds` and + `targets`:: + + seeds = [[0]] + targets = [[1, 2], [3, 4, 5]] + + would be returned as:: + + indices = (np.array([[0, -1, -1], [0, -1, -1]]), # seeds + np.array([[1, 2, -1], [3, 4, 5]])) # targets + + where the indices have been padded with ``-1`` to have shape `(2, n_cons, + max_n_chans)`, where `n_cons = n_unique_seeds * n_unique_targets` and + `max_n_chans = 3`. More information on working with multivariate indices + and handling connections where the number of seeds and targets are not + equal can be found in the :doc:`../auto_examples/handling_ragged_arrays` + example. """ array_like = (np.ndarray, list, tuple) @@ -170,6 +259,9 @@ def seed_target_multivariate_indices(seeds, targets): if not isinstance(inds, array_like): raise TypeError( '`seeds` and `targets` must contain nested array-likes') + if len(inds) != len(np.unique(inds)): + raise ValueError( + '`seeds` and `targets` cannot contain repeated channels') n_chans.append(len(inds)) max_n_chans = max(n_chans) n_cons = len(seeds) * len(targets) From 005d87a0067b21e277a1aa36105df54c2f66043a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 14 Aug 2023 17:05:04 +0200 Subject: [PATCH 08/40] bug fix utils doc update --- mne_connectivity/utils/utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 515f7f53..22321bbe 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -108,8 +108,8 @@ def check_multivariate_indices(indices): connections or between the seeds and targets for a given connection (i.e. ragged indices), the returned array will be padded with the invalid channel index ``-1`` according to the maximum number of channels in the seed or - target of any one connection. E.g. the ragged indices of shape `(2, n_cons, - variable)`:: + target of any one connection. E.g. the ragged indices of shape ``(2, + n_cons, variable)``:: indices = ([[0, 1], [0, 1 ]], # seeds [[2, 3], [4, 5, 6]]) # targets @@ -119,8 +119,8 @@ def check_multivariate_indices(indices): indices = (np.array([[0, 1, -1], [0, 1, -1]]), # seeds np.array([[2, 3, -1], [4, 5, -1]])) # targets - where the indices have been padded with ``-1`` to have shape `(2, n_cons, - max_n_chans)`, where `max_n_chans = 3`. More information on working with + where the indices have been padded with ``-1`` to have shape ``(2, n_cons, + max_n_chans)``, where ``max_n_chans = 3``. More information on working with multivariate indices and handling connections where the number of seeds and targets are not equal can be found in the :doc:`../auto_examples/handling_ragged_arrays` example. @@ -183,8 +183,8 @@ def seed_target_indices(seeds, targets): indices = (np.array([0, 0, 0, 1, 1, 1]), # seeds np.array([2, 3, 4, 2, 3, 4])) # targets - where the indices have been expanded to have shape `(2, n_cons)`, where - `n_cons = n_unique_seeds * n_unique_targets`. + where the indices have been expanded to have shape ``(2, n_cons)``, where + ``n_cons = n_unique_seeds * n_unique_targets``. """ # make them arrays seeds = np.asarray((seeds,)).ravel() @@ -239,9 +239,9 @@ def seed_target_multivariate_indices(seeds, targets): indices = (np.array([[0, -1, -1], [0, -1, -1]]), # seeds np.array([[1, 2, -1], [3, 4, 5]])) # targets - where the indices have been padded with ``-1`` to have shape `(2, n_cons, - max_n_chans)`, where `n_cons = n_unique_seeds * n_unique_targets` and - `max_n_chans = 3`. More information on working with multivariate indices + where the indices have been padded with ``-1`` to have shape ``(2, n_cons, + max_n_chans)``, where ``n_cons = n_unique_seeds * n_unique_targets`` and + ``max_n_chans = 3``. More information on working with multivariate indices and handling connections where the number of seeds and targets are not equal can be found in the :doc:`../auto_examples/handling_ragged_arrays` example. From 96ddda31b915184af067a139b6999840f299fa32 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 14 Aug 2023 17:27:38 +0200 Subject: [PATCH 09/40] bug fix utils doc update --- mne_connectivity/utils/utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 22321bbe..e7508f8f 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -170,10 +170,10 @@ def seed_target_indices(seeds, targets): Notes ----- - `seeds` and `targets` should be array-likes or integers representing the - indices of the channel pairs in the data for each connection. `seeds` and - `targets` will be expanded such that connectivity will be computed between - each seed and each target. E.g. the seeds and targets:: + ``seeds`` and ``targets`` should be array-likes or integers representing + the indices of the channel pairs in the data for each connection. ``seeds`` + and ``targets`` will be expanded such that connectivity will be computed + between each seed and each target. E.g. the seeds and targets:: seeds = [0, 1] targets = [2, 3, 4] @@ -217,19 +217,19 @@ def seed_target_multivariate_indices(seeds, targets): Notes ----- - `seeds` and `targets` should be array-likes representing the indices of the - channel sets in the data for each connection. The indices for each + ``seeds`` and ``targets`` should be array-likes representing the indices of + the channel sets in the data for each connection. The indices for each connection should be an array-like of integers representing the individual channels in the data. The length of indices for each connection do not need to be equal. Furthermore, all indices within a connection must be unique. - `seeds` and `targets` will be expanded such that connectivity will be + ``seeds`` and ``targets`` will be expanded such that connectivity will be computed between each set of seeds and targets. In case the number of channels differs across connections or between the seeds and targets for a given connection (i.e. ragged indices), the returned array will be padded with the invalid channel index ``-1`` according to the maximum number of - channels in the seed or target of any one connection. E.g. `seeds` and - `targets`:: + channels in the seed or target of any one connection. E.g. ``seeds`` and + ``targets``:: seeds = [[0]] targets = [[1, 2], [3, 4, 5]] @@ -252,16 +252,16 @@ def seed_target_multivariate_indices(seeds, targets): not isinstance(seeds, array_like) or not isinstance(targets, array_like) ): - raise TypeError('`seeds` and `targets` must be array-like') + raise TypeError('``seeds`` and ``targets`` must be array-like') n_chans = [] for inds in [*seeds, *targets]: if not isinstance(inds, array_like): raise TypeError( - '`seeds` and `targets` must contain nested array-likes') + '``seeds`` and ``targets`` must contain nested array-likes') if len(inds) != len(np.unique(inds)): raise ValueError( - '`seeds` and `targets` cannot contain repeated channels') + '``seeds`` and ``targets`` cannot contain repeated channels') n_chans.append(len(inds)) max_n_chans = max(n_chans) n_cons = len(seeds) * len(targets) From d1bcde26a5c59e5d929090e5f91ee6f5e36fe549 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 14 Aug 2023 17:42:30 +0200 Subject: [PATCH 10/40] bug fix utils doc update --- mne_connectivity/utils/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index e7508f8f..851a3e37 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -252,16 +252,16 @@ def seed_target_multivariate_indices(seeds, targets): not isinstance(seeds, array_like) or not isinstance(targets, array_like) ): - raise TypeError('``seeds`` and ``targets`` must be array-like') + raise TypeError('`seeds` and `targets` must be array-like') n_chans = [] for inds in [*seeds, *targets]: if not isinstance(inds, array_like): raise TypeError( - '``seeds`` and ``targets`` must contain nested array-likes') + '`seeds` and `targets` must contain nested array-likes') if len(inds) != len(np.unique(inds)): raise ValueError( - '``seeds`` and ``targets`` cannot contain repeated channels') + '`seeds` and `targets` cannot contain repeated channels') n_chans.append(len(inds)) max_n_chans = max(n_chans) n_cons = len(seeds) * len(targets) From 56ad89e69dbfa66005d659b769f9e7fdad64a7d6 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 28 Aug 2023 16:06:19 +0200 Subject: [PATCH 11/40] updated spectral tests --- mne_connectivity/io.py | 4 ++ .../spectral/tests/test_spectral.py | 37 ++++++++++++++++--- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/mne_connectivity/io.py b/mne_connectivity/io.py index 2c99eebc..63aa3501 100644 --- a/mne_connectivity/io.py +++ b/mne_connectivity/io.py @@ -53,6 +53,10 @@ def _xarray_to_conn(array, cls_func): event_id = dict(zip(event_id_keys, event_id_vals)) array.attrs['event_id'] = event_id + # convert indices numpy arrays to a tuple of arrays + if isinstance(array.attrs['indices'], np.ndarray): + array.attrs['indices'] = tuple(array.attrs['indices']) + # create the connectivity class conn = cls_func( data=data, names=names, metadata=metadata, **array.attrs diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index a9d6dfa6..3a41c0ed 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1282,9 +1282,34 @@ def test_save(tmp_path): faverage=True) conn.save(tmp_path / 'foo.nc') - # multivariate connectivity - # use ragged indices & MIC to test padding of indices and patterns - indices = (np.array([[0, 1]]), np.array([[2]])) - conn_mvc = spectral_connectivity_epochs( - epochs, method="mic", indices=indices, sfreq=sfreq, fmin=10, fmax=40) - conn_mvc.save(tmp_path / 'foo_mvc.nc') + +def test_multivar_save_load(tmp_path): + """Test saving and loading results of multivariate connectivity.""" + rng = np.random.RandomState(0) + n_epochs, n_chs, n_times, sfreq, f = 10, 4, 2000, 1000., 20. + data = rng.randn(n_epochs, n_chs, n_times) + sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) + data[:, :, 500:1500] += sig + info = create_info(n_chs, sfreq, 'eeg') + tmin = -1 + epochs = EpochsArray(data, info, tmin=tmin) + tmp_file = os.path.join(tmp_path, 'foo_mvc.nc') + + non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]])) + ragged_indices = (np.array([[0, 1]]), np.array([[2]])) + for indices in [non_ragged_indices, ragged_indices]: + con = spectral_connectivity_epochs( + epochs, method=['mic', 'mim', 'gc', 'gc_tr'], indices=indices, + sfreq=sfreq, fmin=10, fmax=30) + for this_con in con: + this_con.save(tmp_file) + read_con = read_connectivity(tmp_file) + assert_array_almost_equal(this_con.get_data(), + read_con.get_data('raveled')) + if this_con.attrs['patterns'] is not None: + assert_array_almost_equal(np.array(this_con.attrs['patterns']), + np.array(read_con.attrs['patterns'])) + # split `repr` before the file size (`~23 kB` for example) + a = repr(this_con).split('~')[0] + b = repr(read_con).split('~')[0] + assert a == b From e1f4179513cfc8acac8cc883865c81163f8b8143 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 31 Aug 2023 11:21:15 +0200 Subject: [PATCH 12/40] added note for refactoring --- mne_connectivity/spectral/epochs.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index fc617bd7..b92af188 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -200,6 +200,12 @@ def _check_rank_input(rank, data, indices): else: data_arr = data + # XXX: Unpadding of arrays after already padding them is perhaps not so + # efficient. However, we need to remove the padded values to + # ensure only the correct channels are indexed, and having two + # versions of indices is a bit messy currently. A candidate for + # refactoring to simplify code. + for group_i in range(2): # seeds and targets for con_i, con_idcs in enumerate(indices[group_i]): con_idcs = con_idcs[con_idcs != -1] # -1 is padded value From 8a76736cbeb917d13402444f44f1a20e7aa6e555 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 1 Sep 2023 14:38:00 +0200 Subject: [PATCH 13/40] updated spectral tests --- .../spectral/tests/test_spectral.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 3a41c0ed..212ca0cf 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1286,7 +1286,7 @@ def test_save(tmp_path): def test_multivar_save_load(tmp_path): """Test saving and loading results of multivariate connectivity.""" rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 10, 4, 2000, 1000., 20. + n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000., 20. data = rng.randn(n_epochs, n_chs, n_times) sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) data[:, :, 500:1500] += sig @@ -1313,3 +1313,48 @@ def test_multivar_save_load(tmp_path): a = repr(this_con).split('~')[0] b = repr(read_con).split('~')[0] assert a == b + + +def test_spectral_connectivity_indices_maintained(tmp_path): + """Test that indices values and type is maintained after saving. + + If `indices` is None, `indices` in the returned connectivity object should + be None, otherwise, `indices` should be a tuple. The type of `indices` and + its values should be retained after saving and reloading. + """ + rng = np.random.RandomState(0) + n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000., 20. + data = rng.randn(n_epochs, n_chs, n_times) + sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) + data[:, :, 500:1500] += sig + info = create_info(n_chs, sfreq, 'eeg') + tmin = -1 + epochs = EpochsArray(data, info, tmin=tmin) + freqs = np.arange(10, 31) + tmp_file = os.path.join(tmp_path, 'foo_mvc.nc') + + bivar_indices = (np.array([0, 1]), np.array([2, 3])) + multivar_indices = (np.array([[0, 1]]), np.array([[2, 3]])) + indices = [None, bivar_indices, None, multivar_indices] + methods = ['coh', 'coh', 'mic', 'mic'] + + for this_indices, this_method in zip(indices, methods): + con_epochs = spectral_connectivity_epochs( + epochs, method=this_method, indices=this_indices, sfreq=sfreq, + fmin=10, fmax=30) + con_time = spectral_connectivity_time( + epochs, freqs, method=this_method, indices=this_indices, + sfreq=sfreq) + + for con in [con_epochs, con_time]: + con.save(tmp_file) + read_con = read_connectivity(tmp_file) + if this_indices is not None: + # check indices of same type (tuples) + assert (isinstance(con.indices, tuple) and + isinstance(read_con.indices, tuple)) + # check indices have same values + assert np.all(np.array(con.indices) == + np.array(read_con.indices)) + else: + assert con.indices is None and read_con.indices is None From 527aaf8845226e1ff1f8bd47b0f80b8395415b5b Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 1 Sep 2023 13:49:14 -0600 Subject: [PATCH 14/40] Update ignore words Signed-off-by: Adam Li --- ignore_words.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ignore_words.txt b/ignore_words.txt index 6fc27590..1375fa21 100644 --- a/ignore_words.txt +++ b/ignore_words.txt @@ -1,2 +1,4 @@ nd adn +ba +BA \ No newline at end of file From 5f72b6a5379e169cc55bb89a871851b426761493 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 8 Sep 2023 19:16:42 +0200 Subject: [PATCH 15/40] added error message --- mne_connectivity/spectral/epochs.py | 12 +++++++----- mne_connectivity/spectral/time.py | 12 +++++++----- mne_connectivity/tests/test_utils.py | 15 +++++++++++++++ mne_connectivity/utils/utils.py | 12 +++++++++++- 4 files changed, 40 insertions(+), 11 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index b92af188..cc9016e7 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -1569,11 +1569,13 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, 'mim', 'gc', 'gc_tr]``) cannot be called with the other methods. indices : tuple of array | None Two arrays with indices of connections for which to compute - connectivity. If a multivariate method is called, each array for the - seeds and targets should contain a nested array of channel indices for - the individual connections. If None, connections between all channels - are computed, unless a Granger causality method is called, in which - case an error is raised. + connectivity. If a bivariate method is called, each array for the seeds + and targets should contain the channel indices for the each bivariate + connection. If a multivariate method is called, each array for the + seeds and targets should consist of nested arrays containing + the channel indices for each multivariate connection. If None, + connections between all channels are computed, unless a Granger + causality method is called, in which case an error is raised. sfreq : float The sampling frequency. Required if data is not :class:`Epochs `. diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 6b5eb000..3798f699 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -70,11 +70,13 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, :class:`EpochSpectralConnectivity`. indices : tuple of array_like | None Two arrays with indices of connections for which to compute - connectivity. If a multivariate method is called, each array for the - seeds and targets should contain a nested array of channel indices for - the individual connections. If None, connections between all channels - are computed, unless a Granger causality method is called, in which - case an error is raised. + connectivity. If a bivariate method is called, each array for the seeds + and targets should contain the channel indices for the each bivariate + connection. If a multivariate method is called, each array for the + seeds and targets should consist of nested arrays containing + the channel indices for each multivariate connection. If None, + connections between all channels are computed, unless a Granger + causality method is called, in which case an error is raised. sfreq : float The sampling frequency. Required if data is not :class:`Epochs `. diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index 0d678da2..1e5822eb 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -72,6 +72,10 @@ def test_check_indices(): with pytest.raises(ValueError, match='Index arrays indices'): non_equal_len_indices = ([0], [1, 2]) check_indices(non_equal_len_indices) + with pytest.raises(TypeError, + match='Channel indices must be integers, not array'): + nested_indices = ([[0]], [[1]]) + check_indices(nested_indices) # multivariate indices # non-ragged indices @@ -87,6 +91,17 @@ def test_check_indices(): assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), np.array([[2, 3, 4], [4, -1, -1]]))) # test error catching + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_tuple_indices = [np.array([0, 1]), np.array([2, 3])] + check_multivariate_indices(non_tuple_indices) + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_len2_indices = (np.array([0]), np.array([1]), np.array([2])) + check_multivariate_indices(non_len2_indices) + with pytest.raises(ValueError, match='index arrays indices'): + non_equal_len_indices = (np.array([[0]]), np.array([[1], [2]])) + check_multivariate_indices(non_equal_len_indices) with pytest.raises(TypeError, match='multivariate indices must contain array-likes'): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 851a3e37..550fbe68 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -76,6 +76,10 @@ def check_indices(indices): raise ValueError('Index arrays indices[0] and indices[1] must ' 'have the same length') + if any(isinstance(inds, (np.ndarray, list, tuple)) for inds in + [*indices[0], *indices[1]]): + raise TypeError('Channel indices must be integers, not array-likes') + return indices @@ -125,7 +129,13 @@ def check_multivariate_indices(indices): targets are not equal can be found in the :doc:`../auto_examples/handling_ragged_arrays` example. """ - indices = check_indices(indices) + if not isinstance(indices, tuple) or len(indices) != 2: + raise ValueError('indices must be a tuple of length 2') + + if len(indices[0]) != len(indices[1]): + raise ValueError('index arrays indices[0] and indices[1] must ' + 'have the same length') + n_cons = len(indices[0]) n_chans = [] From ea03b902df81f46253cba63820e8ce560ea4ee0b Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sat, 9 Sep 2023 12:31:44 +0200 Subject: [PATCH 16/40] Added formatting suggestions Co-authored-by: Daniel McCloy --- mne_connectivity/spectral/epochs.py | 4 ++-- mne_connectivity/utils/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index cc9016e7..5158b3c0 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -1570,10 +1570,10 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, indices : tuple of array | None Two arrays with indices of connections for which to compute connectivity. If a bivariate method is called, each array for the seeds - and targets should contain the channel indices for the each bivariate + and targets should contain the channel indices for each bivariate connection. If a multivariate method is called, each array for the seeds and targets should consist of nested arrays containing - the channel indices for each multivariate connection. If None, + the channel indices for each multivariate connection. If ``None``, connections between all channels are computed, unless a Granger causality method is called, in which case an error is raised. sfreq : float diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 550fbe68..58905ac7 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -50,7 +50,7 @@ def par(x): def check_indices(indices): - """Check indices parameter for bivariate connetivity. + """Check indices parameter for bivariate connectivity. Parameters ---------- From b3b0f97e04aceecaa906a1436dfb1c97c30ba035 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sat, 9 Sep 2023 12:46:06 +0200 Subject: [PATCH 17/40] added max_n_chans suggestion --- mne_connectivity/utils/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 550fbe68..8cccae25 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -138,7 +138,7 @@ def check_multivariate_indices(indices): n_cons = len(indices[0]) - n_chans = [] + max_n_chans = 0 for inds in ([*indices[0], *indices[1]]): if not isinstance(inds, (np.ndarray, list, tuple)): raise TypeError( @@ -148,8 +148,7 @@ def check_multivariate_indices(indices): raise ValueError( 'multivariate indices cannot contain repeated channels within ' 'a seed or target') - n_chans.append(len(inds)) - max_n_chans = np.max(n_chans) + max_n_chans = max(max_n_chans, len(inds)) # pad indices to avoid ragged arrays padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32), From eda3c6bb07cb5647cfdc221c7a63819e21da5010 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sat, 9 Sep 2023 13:00:12 +0200 Subject: [PATCH 18/40] updated epochs docstring --- mne_connectivity/spectral/epochs.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 5158b3c0..7ad551c9 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -1641,13 +1641,15 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, con : array | list of array Computed connectivity measure(s). Either an instance of ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of each connectivity dataset is either: - (n_cons, n_freqs) mode: 'multitaper' or 'fourier'; or - (n_cons, n_freqs, n_times) mode: 'cwt_morlet'. When "indices" is None - and a bivariate method is called, "n_cons = n_signals ** 2", or if a - multivariate method is called "n_cons = 1". When "indices" is - specified, "n_con = len(indices[0])" for bivariate and multivariate - methods. + The shape of the connectivity result will be: + + - ``(n_cons, n_freqs)`` for multitaper or fourier modes + - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode + - ``n_cons = n_signals ** 2`` for bivariate methods with + ``indices=None`` + - ``n_cons = 1`` for multivariate methods with ``indices=None`` + - ``n_cons = len(indices[0])`` for bivariate and multivariate methods + when indices is supplied. See Also -------- From 437e72a006005e9f9638c106d6a99798a9acf8c6 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sat, 9 Sep 2023 13:03:57 +0200 Subject: [PATCH 19/40] added test suggestion Co-authored-by: Adam Li --- .../spectral/tests/test_spectral.py | 69 ++++++++++--------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 212ca0cf..3f7416f1 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1315,7 +1315,12 @@ def test_multivar_save_load(tmp_path): assert a == b -def test_spectral_connectivity_indices_maintained(tmp_path): +@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", "mic", "mim"]) +@pytest.mark.parametrize("indices", [None, + (np.array([0, 1]), np.array([2, 3])), + (np.array([[0, 1]]), np.array([[2, 3]])) + ]) +def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. If `indices` is None, `indices` in the returned connectivity object should @@ -1323,38 +1328,40 @@ def test_spectral_connectivity_indices_maintained(tmp_path): its values should be retained after saving and reloading. """ rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000., 20. + n_epochs, n_chs, n_times, sfreq, f = 5, 4, 200, 100.0, 20.0 data = rng.randn(n_epochs, n_chs, n_times) - sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) - data[:, :, 500:1500] += sig - info = create_info(n_chs, sfreq, 'eeg') + info = create_info(n_chs, sfreq, "eeg") tmin = -1 epochs = EpochsArray(data, info, tmin=tmin) freqs = np.arange(10, 31) - tmp_file = os.path.join(tmp_path, 'foo_mvc.nc') + tmp_file = os.path.join(tmp_path, "foo_mvc.nc") - bivar_indices = (np.array([0, 1]), np.array([2, 3])) - multivar_indices = (np.array([[0, 1]]), np.array([[2, 3]])) - indices = [None, bivar_indices, None, multivar_indices] - methods = ['coh', 'coh', 'mic', 'mic'] - - for this_indices, this_method in zip(indices, methods): - con_epochs = spectral_connectivity_epochs( - epochs, method=this_method, indices=this_indices, sfreq=sfreq, - fmin=10, fmax=30) - con_time = spectral_connectivity_time( - epochs, freqs, method=this_method, indices=this_indices, - sfreq=sfreq) - - for con in [con_epochs, con_time]: - con.save(tmp_file) - read_con = read_connectivity(tmp_file) - if this_indices is not None: - # check indices of same type (tuples) - assert (isinstance(con.indices, tuple) and - isinstance(read_con.indices, tuple)) - # check indices have same values - assert np.all(np.array(con.indices) == - np.array(read_con.indices)) - else: - assert con.indices is None and read_con.indices is None + # mutlivariate methods and bivariate methods require the right indices shape + if method in ["mic", "mim"]: + if indices is not None and indices[0].ndim == 1: + pytest.skip() + else: + if indices is not None and indices[0].ndim == 2: + pytest.skip() + + # actually test the pair of method and indices defined to check the output indices + con_epochs = spectral_connectivity_epochs( + epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 + ) + con_time = spectral_connectivity_time( + epochs, freqs, method=method, indices=indices, sfreq=sfreq + ) + + for con in [con_epochs, con_time]: + con.save(tmp_file) + read_con = read_connectivity(tmp_file) + + if indices is not None: + # check indices of same type (tuples) + assert isinstance(con.indices, tuple) and isinstance( + read_con.indices, tuple + ) + # check indices have same values + assert np.all(np.array(con.indices) == np.array(read_con.indices)) + else: + assert con.indices is None and read_con.indices is None From 8844afc8544e2f14607c1346e631fd0effe1ef24 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sat, 9 Sep 2023 13:16:54 +0200 Subject: [PATCH 20/40] fixed style errors --- mne_connectivity/spectral/tests/test_spectral.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 3f7416f1..592291f0 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1315,8 +1315,9 @@ def test_multivar_save_load(tmp_path): assert a == b -@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", "mic", "mim"]) -@pytest.mark.parametrize("indices", [None, +@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", + "mic", "mim"]) +@pytest.mark.parametrize("indices", [None, (np.array([0, 1]), np.array([2, 3])), (np.array([[0, 1]]), np.array([[2, 3]])) ]) @@ -1328,7 +1329,7 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): its values should be retained after saving and reloading. """ rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 5, 4, 200, 100.0, 20.0 + n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 data = rng.randn(n_epochs, n_chs, n_times) info = create_info(n_chs, sfreq, "eeg") tmin = -1 @@ -1336,15 +1337,15 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") - # mutlivariate methods and bivariate methods require the right indices shape + # mutlivariate and bivariate methods require the right indices shape if method in ["mic", "mim"]: if indices is not None and indices[0].ndim == 1: pytest.skip() else: if indices is not None and indices[0].ndim == 2: pytest.skip() - - # actually test the pair of method and indices defined to check the output indices + + # test the pair of method and indices defined to check the output indices con_epochs = spectral_connectivity_epochs( epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 ) @@ -1355,7 +1356,7 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): for con in [con_epochs, con_time]: con.save(tmp_file) read_con = read_connectivity(tmp_file) - + if indices is not None: # check indices of same type (tuples) assert isinstance(con.indices, tuple) and isinstance( From 028543973496da94f5e351f5c854c31922bdaad8 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 24 Oct 2023 17:36:19 +0200 Subject: [PATCH 21/40] Squashed commit of the following: commit 36f88de81dcc8e88afddbfb7b3bd3cb432aba535 Merge: 824dbcc 7b70d0b Author: Thomas Samuel Binns Date: Tue Oct 24 17:35:30 2023 +0200 Merge branch 'hackathon_2023' of https://github.com/tsbinns/mne-connectivity into hackathon_2023 commit 824dbcc72d3bc7b2256e9e1fcb3bab2883fae8b3 Author: Thomas Samuel Binns Date: Mon Oct 23 18:58:14 2023 +0200 refactor epochs functions commit 7b70d0bbdbdf22a2525e5cfde4619300f33f86ae Author: Thomas Samuel Binns Date: Tue Oct 24 17:34:19 2023 +0200 bug fix refactoring commit dd46076ffbe566a556bf21e14485baac4a9c4001 Author: Thomas Samuel Binns Date: Tue Oct 24 17:21:07 2023 +0200 bug fix refactoring commit 71fa86c45dc62f1194e132d5b97e21ad17c851ce Author: Thomas Samuel Binns Date: Tue Oct 24 16:59:30 2023 +0200 bug fix refactoring commit 2ce29eb9ae98f0b7f6fcb39f182c7185624a88ea Author: Thomas Samuel Binns Date: Tue Oct 24 16:44:02 2023 +0200 bug fix refactoring commit ffdf82fbd896691b18fc4f737eec02e58029ec3a Author: Thomas Samuel Binns Date: Tue Oct 24 14:39:34 2023 +0200 refactor new funcs commit 0c4d5c6fdef65651f5d71868d2bdb763207b038f Author: Thomas Samuel Binns Date: Tue Oct 24 14:26:36 2023 +0200 bug fix refactoring commit 3e93cea72bfaa109f7b6f5d7efe861e32d125766 Author: Thomas Samuel Binns Date: Tue Oct 24 11:55:24 2023 +0200 refactor new funcs commit 174efc7418c75924b4ae8905389758f1f4de8d5c Author: Thomas Samuel Binns Date: Mon Oct 23 22:24:32 2023 +0200 refactor new funcs commit 5d5d74dbe046e776a7127bc499f179bebf1b2596 Author: Thomas Samuel Binns Date: Mon Oct 23 19:04:04 2023 +0200 refactor new funcs commit 014c22e74ffd5699da0c803f22912b32b04b2ace Author: Thomas Samuel Binns Date: Mon Oct 23 19:03:18 2023 +0200 refactor new funcs commit 6282e678d76ebbf4a7d161f4bfd7a59dd0f427c4 Author: Thomas Samuel Binns Date: Mon Oct 23 18:58:14 2023 +0200 refactor new funcs --- doc/api.rst | 1 + examples/granger_causality.py | 14 +- examples/mic_mim.py | 10 +- mne_connectivity/__init__.py | 4 +- mne_connectivity/spectral/__init__.py | 3 +- mne_connectivity/spectral/epochs.py | 1693 +---------------- mne_connectivity/spectral/epochs_bivariate.py | 729 +++++++ .../spectral/epochs_multivariate.py | 1129 +++++++++++ .../spectral/tests/test_spectral.py | 158 +- mne_connectivity/spectral/time.py | 5 +- 10 files changed, 2053 insertions(+), 1693 deletions(-) create mode 100644 mne_connectivity/spectral/epochs_bivariate.py create mode 100644 mne_connectivity/spectral/epochs_multivariate.py diff --git a/doc/api.rst b/doc/api.rst index c91f9c02..3fe85832 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -47,6 +47,7 @@ on numpy array inputs. phase_slope_index vector_auto_regression spectral_connectivity_epochs + spectral_connectivity_epochs_multivariate spectral_connectivity_time Reading functions diff --git a/examples/granger_causality.py b/examples/granger_causality.py index 64a657db..4129dadc 100644 --- a/examples/granger_causality.py +++ b/examples/granger_causality.py @@ -20,7 +20,7 @@ import mne from mne.datasets.fieldtrip_cmc import data_path -from mne_connectivity import spectral_connectivity_epochs +from mne_connectivity import spectral_connectivity_epochs_multivariate ############################################################################### # Background @@ -161,10 +161,10 @@ indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A # compute Granger causality -gc_ab = spectral_connectivity_epochs( +gc_ab = spectral_connectivity_epochs_multivariate( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # A => B -gc_ba = spectral_connectivity_epochs( +gc_ba = spectral_connectivity_epochs_multivariate( epochs, method=['gc'], indices=indices_ba, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # B => A freqs = gc_ab.freqs @@ -262,10 +262,10 @@ # %% # compute GC on time-reversed signals -gc_tr_ab = spectral_connectivity_epochs( +gc_tr_ab = spectral_connectivity_epochs_multivariate( epochs, method=['gc_tr'], indices=indices_ab, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # TR[A => B] -gc_tr_ba = spectral_connectivity_epochs( +gc_tr_ba = spectral_connectivity_epochs_multivariate( epochs, method=['gc_tr'], indices=indices_ba, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # TR[B => A] @@ -317,7 +317,7 @@ # %% -gc_ab_60 = spectral_connectivity_epochs( +gc_ab_60 = spectral_connectivity_epochs_multivariate( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=60) # A => B @@ -375,7 +375,7 @@ # %% try: - spectral_connectivity_epochs( + spectral_connectivity_epochs_multivariate( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None, gc_n_lags=20, verbose=False) # A => B print('Success!') diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 87111586..62674e75 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -25,7 +25,9 @@ import mne from mne import EvokedArray, make_fixed_length_epochs from mne.datasets.fieldtrip_cmc import data_path -from mne_connectivity import seed_target_indices, spectral_connectivity_epochs +from mne_connectivity import (seed_target_indices, + spectral_connectivity_epochs, + spectral_connectivity_epochs_multivariate) ############################################################################### # Background @@ -87,7 +89,7 @@ target_names = [epochs.info['ch_names'][idx] for idx in targets] # multivariate imaginary part of coherency -(mic, mim) = spectral_connectivity_epochs( +(mic, mim) = spectral_connectivity_epochs_multivariate( epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, rank=None) @@ -290,7 +292,7 @@ # %% indices = (np.array([[*seeds, *targets]]), np.array([[*seeds, *targets]])) -gim = spectral_connectivity_epochs( +gim = spectral_connectivity_epochs_multivariate( epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None, verbose=False) @@ -342,7 +344,7 @@ # %% -(mic_red, mim_red) = spectral_connectivity_epochs( +(mic_red, mim_red) = spectral_connectivity_epochs_multivariate( epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, rank=([25], [25])) diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index c2f03a6c..32488b33 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -15,7 +15,9 @@ from .effective import phase_slope_index from .envelope import envelope_correlation, symmetric_orth from .io import read_connectivity -from .spectral import spectral_connectivity_time, spectral_connectivity_epochs +from .spectral import (spectral_connectivity_time, + spectral_connectivity_epochs, + spectral_connectivity_epochs_multivariate) from .vector_ar import vector_auto_regression, select_order from .utils import (check_indices, check_multivariate_indices, degree, seed_target_indices, seed_target_multivariate_indices) diff --git a/mne_connectivity/spectral/__init__.py b/mne_connectivity/spectral/__init__.py index a0f06ef6..f2252db9 100644 --- a/mne_connectivity/spectral/__init__.py +++ b/mne_connectivity/spectral/__init__.py @@ -1,2 +1,3 @@ -from .epochs import spectral_connectivity_epochs +from .epochs_bivariate import spectral_connectivity_epochs +from .epochs_multivariate import spectral_connectivity_epochs_multivariate from .time import spectral_connectivity_time \ No newline at end of file diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 7ad551c9..5742ae33 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -2,8 +2,6 @@ # Denis A. Engemann # Adam Li # Thomas S. Binns -# Tien D. Nguyen -# Richard M. Köhler # # License: BSD (3-clause) @@ -11,20 +9,16 @@ import inspect import numpy as np -import scipy as sp from mne.epochs import BaseEpochs from mne.parallel import parallel_func from mne.source_estimate import _BaseSourceEstimate -from mne.time_frequency.multitaper import (_csd_from_mt, - _mt_spectra, _psd_from_mt, - _psd_from_mt_adaptive) +from mne.time_frequency.multitaper import ( + _csd_from_mt, _mt_spectra, _psd_from_mt, _psd_from_mt_adaptive) from mne.time_frequency.tfr import cwt, morlet from mne.time_frequency.multitaper import _compute_mt_params -from mne.utils import ( - ProgressBar, _arange_div, _check_option, _time_mask, logger, warn, verbose) +from mne.utils import _arange_div, _check_option, _time_mask, logger, warn -from ..base import (SpectralConnectivity, SpectroTemporalConnectivity) -from ..utils import fill_doc, check_indices, check_multivariate_indices +from ..base import SpectralConnectivity, SpectroTemporalConnectivity def _compute_freqs(n_times, sfreq, cwt_freqs, mode): @@ -63,10 +57,8 @@ def _compute_freq_mask(freqs_all, fmin, fmax, fskip): return freq_mask -def _prepare_connectivity(epoch_block, times_in, tmin, tmax, - fmin, fmax, sfreq, indices, - method, mode, fskip, n_bands, - cwt_freqs, faverage): +def _prepare_connectivity(epoch_block, times_in, tmin, tmax, fmin, fmax, sfreq, + mode, fskip, n_bands, cwt_freqs, faverage): """Check and precompute dimensions of results data.""" first_epoch = epoch_block[0] @@ -92,43 +84,6 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, times = times_in[tmin_idx:tmax_idx] n_times = len(times) - if any(this_method in _multivariate_methods for this_method in method): - multivariate_con = True - else: - multivariate_con = False - - if indices is None: - if multivariate_con: - if any(this_method in _gc_methods for this_method in method): - raise ValueError( - 'indices must be specified when computing Granger ' - 'causality, as all-to-all connectivity is not supported') - else: - logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], - np.arange(n_signals, dtype=int)[np.newaxis, :]) - else: - logger.info('only using indices for lower-triangular matrix') - # only compute r for lower-triangular region - indices_use = np.tril_indices(n_signals, -1) - else: - if multivariate_con: - indices_use = check_multivariate_indices(indices) # pad with -1 - if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices[0], indices[1]): - intersection = np.intersect1d(seed, target) - if np.any(intersection != -1): # ignore padded entries - raise ValueError( - 'seed and target indices must not intersect when ' - 'computing Granger causality') - else: - indices_use = check_indices(indices) - - # number of connectivities to compute - n_cons = len(indices_use[0]) - - logger.info(' computing connectivity for %d connections' - % n_cons) logger.info(' using t=%0.3fs..%0.3fs for estimation (%d points)' % (tmin_true, tmax_true, n_times)) @@ -184,55 +139,9 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, logger.info(' connectivity scores will be averaged for ' 'each band') - return (n_cons, times, n_times, times_in, n_times_in, tmin_idx, + return (times, n_times, times_in, n_times_in, tmin_idx, tmax_idx, n_freqs, freq_mask, freqs, freqs_bands, freq_idx_bands, - n_signals, indices_use, warn_times) - - -def _check_rank_input(rank, data, indices): - """Check the rank argument is appropriate and compute rank if missing.""" - sv_tol = 1e-10 # tolerance for non-zero singular val (rel to largest) - if rank is None: - rank = np.zeros((2, len(indices[0])), dtype=int) - - if isinstance(data, BaseEpochs): - data_arr = data.get_data() - else: - data_arr = data - - # XXX: Unpadding of arrays after already padding them is perhaps not so - # efficient. However, we need to remove the padded values to - # ensure only the correct channels are indexed, and having two - # versions of indices is a bit messy currently. A candidate for - # refactoring to simplify code. - - for group_i in range(2): # seeds and targets - for con_i, con_idcs in enumerate(indices[group_i]): - con_idcs = con_idcs[con_idcs != -1] # -1 is padded value - s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) - rank[group_i][con_i] = np.min( - [np.count_nonzero(epoch >= epoch[0] * sv_tol) - for epoch in s]) - - logger.info('Estimated data ranks:') - con_i = 1 - for seed_rank, target_rank in zip(rank[0], rank[1]): - logger.info(' connection %i - seeds (%i); targets (%i)' - % (con_i, seed_rank, target_rank, )) - con_i += 1 - rank = tuple((np.array(rank[0]), np.array(rank[1]))) - - else: - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], rank[0], rank[1]): - if not (0 < seed_rank <= len(seed_idcs) and - 0 < target_rank <= len(target_idcs)): - raise ValueError( - 'ranks for seeds and targets must be > 0 and <= the ' - 'number of channels in the seeds and targets, ' - 'respectively, for each connection') - - return rank + n_signals, warn_times) def _assemble_spectral_params(mode, n_times, mt_adaptive, mt_bandwidth, sfreq, @@ -274,6 +183,46 @@ def _assemble_spectral_params(mode, n_times, mt_adaptive, mt_bandwidth, sfreq, return spectral_params, mt_adaptive, n_times_spectrum, n_tapers +def _compute_spectral_methods_epochs( + con_methods, epoch_block, epoch_idx, call_params, parallel, + my_spectral_connectivity_epochs, n_jobs, n_times_in, times_in, + warn_times +): + """Compute CSD/PSD for spectral_connectivity_epochs... functions.""" + # check dimensions and time scale + for this_epoch in epoch_block: + _, _, _, warn_times = _get_and_verify_data_sizes( + this_epoch, call_params["sfreq"], call_params["n_signals"], + n_times_in, times_in, warn_times=warn_times) + + if n_jobs == 1: + # no parallel processing + for this_epoch in epoch_block: + logger.info(' computing cross-spectral density for epoch %d' + % (epoch_idx + 1)) + # con methods and psd are updated inplace + _epoch_spectral_connectivity(data=this_epoch, **call_params) + epoch_idx += 1 + else: + # process epochs in parallel + logger.info( + ' computing cross-spectral density for epochs %d..%d' + % (epoch_idx + 1, epoch_idx + len(epoch_block))) + + out = parallel(my_spectral_connectivity_epochs( + data=this_epoch, **call_params) + for this_epoch in epoch_block) + # do the accumulation + for this_out in out: + for _method, parallel_method in zip(con_methods, this_out[0]): + _method.combine(parallel_method) + if call_params["psd"] is not None: + call_params["psd"] += this_out[1] + + epoch_idx += len(epoch_block) + + return epoch_idx + ######################################################################## # Various connectivity estimators @@ -293,996 +242,9 @@ def combine(self, other): def compute_con(self, con_idx, n_epochs): raise NotImplementedError('compute_con method not implemented') - -class _EpochMeanConEstBase(_AbstractConEstBase): - """Base class for methods that estimate connectivity as mean epoch-wise.""" - - patterns = None - - def __init__(self, n_cons, n_freqs, n_times): - self.n_cons = n_cons - self.n_freqs = n_freqs - self.n_times = n_times - - if n_times == 0: - self.csd_shape = (n_cons, n_freqs) - else: - self.csd_shape = (n_cons, n_freqs, n_times) - - self.con_scores = None - - def start_epoch(self): # noqa: D401 - """Called at the start of each epoch.""" - pass # for this type of con. method we don't do anything - - def combine(self, other): - """Include con. accumated for some epochs in this estimate.""" - self._acc += other._acc - - -class _EpochMeanMultivariateConEstBase(_AbstractConEstBase): - """Base class for mean epoch-wise multivar. con. estimation methods.""" - - n_steps = None - patterns = None - - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): - self.n_signals = n_signals - self.n_cons = n_cons - self.n_freqs = n_freqs - self.n_times = n_times - self.n_jobs = n_jobs - - # include time dimension, even when unused for indexing flexibility - if n_times == 0: - self.csd_shape = (n_signals**2, n_freqs) - self.con_scores = np.zeros((n_cons, n_freqs, 1)) - else: - self.csd_shape = (n_signals**2, n_freqs, n_times) - self.con_scores = np.zeros((n_cons, n_freqs, n_times)) - - # allocate space for accumulation of CSD - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - self._compute_n_progress_bar_steps() - - def start_epoch(self): # noqa: D401 - """Called at the start of each epoch.""" - pass # for this type of con. method we don't do anything - - def combine(self, other): - """Include con. accumulated for some epochs in this estimate.""" - self._acc += other._acc - - def accumulate(self, con_idx, csd_xy): - """Accumulate CSD for some connections.""" - self._acc[con_idx] += csd_xy - - def _compute_n_progress_bar_steps(self): - """Calculate the number of steps to include in the progress bar.""" - self.n_steps = int(np.ceil(self.n_freqs / self.n_jobs)) - - def _log_connection_number(self, con_i): - """Log the number of the connection being computed.""" - logger.info('Computing %s for connection %i of %i' - % (self.name, con_i + 1, self.n_cons, )) - - def _get_block_indices(self, block_i, limit): - """Get indices for a computation block capped by a limit.""" - indices = np.arange(block_i * self.n_jobs, (block_i + 1) * self.n_jobs) - - return indices[np.nonzero(indices < limit)] - - def reshape_csd(self): - """Reshape CSD into a matrix of times x freqs x signals x signals.""" - if self.n_times == 0: - return (np.reshape(self._acc, ( - self.n_signals, self.n_signals, self.n_freqs, 1) - ).transpose(3, 2, 0, 1)) - - return (np.reshape(self._acc, ( - self.n_signals, self.n_signals, self.n_freqs, self.n_times) - ).transpose(3, 2, 0, 1)) - - -class _CohEstBase(_EpochMeanConEstBase): - """Base Estimator for Coherence, Coherency, Imag. Coherence.""" - - accumulate_psd = True - - def __init__(self, n_cons, n_freqs, n_times): - super(_CohEstBase, self).__init__(n_cons, n_freqs, n_times) - - # allocate space for accumulation of CSD - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate CSD for some connections.""" - self._acc[con_idx] += csd_xy - - -class _CohEst(_CohEstBase): - """Coherence Estimator.""" - - name = 'Coherence' - - def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - csd_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy) - - -class _CohyEst(_CohEstBase): - """Coherency Estimator.""" - - name = 'Coherency' - - def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape, - dtype=np.complex128) - csd_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = csd_mean / np.sqrt(psd_xx * psd_yy) - - -class _ImCohEst(_CohEstBase): - """Imaginary Coherence Estimator.""" - - name = 'Imaginary Coherence' - - def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - csd_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = np.imag(csd_mean) / np.sqrt(psd_xx * psd_yy) - - -class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): - """Base estimator for multivariate imag. part of coherency methods. - - See Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 - for equation references. - """ - - name = None - accumulate_psd = False - - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): - super(_MultivariateCohEstBase, self).__init__( - n_signals, n_cons, n_freqs, n_times, n_jobs) - - def compute_con(self, indices, ranks, n_epochs=1): - """Compute multivariate imag. part of coherency between signals.""" - assert self.name in ['MIC', 'MIM'], ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') - - csd = self.reshape_csd() / n_epochs - n_times = csd.shape[0] - times = np.arange(n_times) - freqs = np.arange(self.n_freqs) - - if self.name == 'MIC': - self.patterns = np.full( - (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), - np.nan) - - con_i = 0 - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], ranks[0], ranks[1]): - self._log_connection_number(con_i) - - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] - con_idcs = [*seed_idcs, *target_idcs] - - C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - - # Eqs. 32 & 33 - C_bar, U_bar_aa, U_bar_bb = self._csd_svd( - C, seed_idcs, seed_rank, target_rank) - - # Eqs. 3 & 4 - E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) - - if self.name == 'MIC': - self._compute_mic(E, C, seed_idcs, target_idcs, n_times, - U_bar_aa, U_bar_bb, con_i) - else: - self._compute_mim(E, seed_idcs, target_idcs, con_i) - - con_i += 1 - - self.reshape_results() - - def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): - """Dimensionality reduction of CSD with SVD.""" - n_times = csd.shape[0] - n_seeds = len(seed_idcs) - n_targets = csd.shape[3] - n_seeds - - C_aa = csd[..., :n_seeds, :n_seeds] - C_ab = csd[..., :n_seeds, n_seeds:] - C_bb = csd[..., n_seeds:, n_seeds:] - C_ba = csd[..., n_seeds:, :n_seeds] - - # Eq. 32 - if seed_rank != n_seeds: - U_aa = np.linalg.svd(np.real(C_aa), full_matrices=False)[0] - U_bar_aa = U_aa[..., :seed_rank] - else: - U_bar_aa = np.broadcast_to( - np.identity(n_seeds), - (n_times, self.n_freqs) + (n_seeds, n_seeds)) - - if target_rank != n_targets: - U_bb = np.linalg.svd(np.real(C_bb), full_matrices=False)[0] - U_bar_bb = U_bb[..., :target_rank] - else: - U_bar_bb = np.broadcast_to( - np.identity(n_targets), - (n_times, self.n_freqs) + (n_targets, n_targets)) - - # Eq. 33 - C_bar_aa = np.matmul( - U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul( - U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul( - U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul( - U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa)) - C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), - np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) - - return C_bar, U_bar_aa, U_bar_bb - - def _compute_e(self, csd, n_seeds): - """Compute E from the CSD.""" - C_r = np.real(csd) - - parallel, parallel_compute_t, _ = parallel_func( - _mic_mim_compute_t, self.n_jobs, verbose=False) - - # imag. part of T filled when data is rank-deficient - T = np.zeros(csd.shape, dtype=np.complex128) - for block_i in ProgressBar( - range(self.n_steps), mesg="frequency blocks"): - freqs = self._get_block_indices(block_i, self.n_freqs) - T[:, freqs] = np.array(parallel(parallel_compute_t( - C_r[:, f], T[:, f], n_seeds) for f in freqs) - ).transpose(1, 0, 2, 3) - - if not np.isreal(T).all() or not np.isfinite(T).all(): - raise RuntimeError( - 'the transformation matrix of the data must be real-valued ' - 'and contain no NaN or infinity values; check that you are ' - 'using full rank data or specify an appropriate rank for the ' - 'seeds and targets that is less than or equal to their ranks') - T = np.real(T) # make T real if check passes - - # Eq. 4 - D = np.matmul(T, np.matmul(csd, T)) - - # E as imag. part of D between seeds and targets - return np.imag(D[..., :n_seeds, n_seeds:]) - - def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, - U_bar_bb, con_i): - """Compute MIC and the associated spatial patterns.""" - n_seeds = len(seed_idcs) - n_targets = len(target_idcs) - times = np.arange(n_times) - freqs = np.arange(self.n_freqs) - - # Eigendecomp. to find spatial filters for seeds and targets - w_seeds, V_seeds = np.linalg.eigh( - np.matmul(E, E.transpose(0, 1, 3, 2))) - w_targets, V_targets = np.linalg.eigh( - np.matmul(E.transpose(0, 1, 3, 2), E)) - if ( - len(seed_idcs) == len(target_idcs) and - np.all(np.sort(seed_idcs) == np.sort(target_idcs)) - ): - # strange edge-case where the eigenvectors returned should be a set - # of identity matrices with one rotated by 90 degrees, but are - # instead identical (i.e. are not rotated versions of one another). - # This leads to the case where the spatial filters are incorrectly - # applied, resulting in connectivity estimates of ~0 when they - # should be perfectly correlated ~1. Accordingly, we manually - # create a set of rotated identity matrices to use as the filters. - create_filter = False - stop = False - while not create_filter and not stop: - for time_i in range(n_times): - for freq_i in range(self.n_freqs): - if np.all(V_seeds[time_i, freq_i] == - V_targets[time_i, freq_i]): - create_filter = True - break - stop = True - if create_filter: - n_chans = E.shape[2] - eye_4d = np.zeros_like(V_seeds) - eye_4d[:, :, np.arange(n_chans), np.arange(n_chans)] = 1 - V_seeds = eye_4d - V_targets = np.rot90(eye_4d, axes=(2, 3)) - - # Spatial filters with largest eigval. for seeds and targets - alpha = V_seeds[times[:, None], freqs, :, w_seeds.argmax(axis=2)] - beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] - - # Eq. 46 (seed spatial patterns) - self.patterns[0, con_i, :n_seeds] = (np.matmul( - np.real(C[..., :n_seeds, :n_seeds]), - np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T - - # Eq. 47 (target spatial patterns) - self.patterns[1, con_i, :n_targets] = (np.matmul( - np.real(C[..., n_seeds:, n_seeds:]), - np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T - - # Eq. 7 - self.con_scores[con_i] = (np.einsum( - 'ijk,ijk->ij', alpha, np.matmul(E, np.expand_dims( - beta, axis=3))[..., 0] - ) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2)).T - - def _compute_mim(self, E, seed_idcs, target_idcs, con_i): - """Compute MIM (a.k.a. GIM if seeds == targets).""" - # Eq. 14 - self.con_scores[con_i] = np.matmul( - E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T - - # Eq. 15 - if ( - len(seed_idcs) == len(target_idcs) and - np.all(np.sort(seed_idcs) == np.sort(target_idcs)) - ): - self.con_scores[con_i] *= 0.5 - - def reshape_results(self): - """Remove time dimension from results, if necessary.""" - if self.n_times == 0: - self.con_scores = self.con_scores[..., 0] - if self.patterns is not None: - self.patterns = self.patterns[..., 0] - - -def _mic_mim_compute_t(C, T, n_seeds): - """Compute T for a single frequency (used for MIC and MIM).""" - for time_i in range(C.shape[0]): - T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( - C[time_i, :n_seeds, :n_seeds], -0.5 - ) - T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power( - C[time_i, n_seeds:, n_seeds:], -0.5 - ) - - return T - - -class _MICEst(_MultivariateCohEstBase): - """Multivariate imaginary part of coherency (MIC) estimator.""" - - name = "MIC" - - -class _MIMEst(_MultivariateCohEstBase): - """Multivariate interaction measure (MIM) estimator.""" - - name = "MIM" - - -class _PLVEst(_EpochMeanConEstBase): - """PLV Estimator.""" - - name = 'PLV' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_PLVEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += csd_xy / np.abs(csd_xy) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - plv = np.abs(self._acc / n_epochs) - self.con_scores[con_idx] = plv - - -class _ciPLVEst(_EpochMeanConEstBase): - """corrected imaginary PLV Estimator.""" - - name = 'ciPLV' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_ciPLVEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += csd_xy / np.abs(csd_xy) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - imag_plv = np.abs(np.imag(self._acc)) / n_epochs - real_plv = np.real(self._acc) / n_epochs - real_plv = np.clip(real_plv, -1, 1) # bounded from -1 to 1 - mask = (np.abs(real_plv) == 1) # avoid division by 0 - real_plv[mask] = 0 - corrected_imag_plv = imag_plv / np.sqrt(1 - real_plv ** 2) - self.con_scores[con_idx] = corrected_imag_plv - - -class _PLIEst(_EpochMeanConEstBase): - """PLI Estimator.""" - - name = 'PLI' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_PLIEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += np.sign(np.imag(csd_xy)) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - pli_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = np.abs(pli_mean) - - -class _PLIUnbiasedEst(_PLIEst): - """Unbiased PLI Square Estimator.""" - - name = 'Unbiased PLI Square' - accumulate_psd = False - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - pli_mean = self._acc[con_idx] / n_epochs - - # See Vinck paper Eq. (30) - con = (n_epochs * pli_mean ** 2 - 1) / (n_epochs - 1) - - self.con_scores[con_idx] = con - - -class _DPLIEst(_EpochMeanConEstBase): - """DPLI Estimator.""" - - name = 'DPLI' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_DPLIEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += np.heaviside(np.imag(csd_xy), 0.5) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - con = self._acc[con_idx] / n_epochs - - self.con_scores[con_idx] = con - - -class _WPLIEst(_EpochMeanConEstBase): - """WPLI Estimator.""" - - name = 'WPLI' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_WPLIEst, self).__init__(n_cons, n_freqs, n_times) - - # store both imag(csd) and abs(imag(csd)) - acc_shape = (2,) + self.csd_shape - self._acc = np.zeros(acc_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - im_csd = np.imag(csd_xy) - self._acc[0, con_idx] += im_csd - self._acc[1, con_idx] += np.abs(im_csd) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - num = np.abs(self._acc[0, con_idx]) - denom = self._acc[1, con_idx] - - # handle zeros in denominator - z_denom = np.where(denom == 0.) - denom[z_denom] = 1. - - con = num / denom - - # where we had zeros in denominator, we set con to zero - con[z_denom] = 0. - - self.con_scores[con_idx] = con - - -class _WPLIDebiasedEst(_EpochMeanConEstBase): - """Debiased WPLI Square Estimator.""" - - name = 'Debiased WPLI Square' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_WPLIDebiasedEst, self).__init__(n_cons, n_freqs, n_times) - # store imag(csd), abs(imag(csd)), imag(csd)^2 - acc_shape = (3,) + self.csd_shape - self._acc = np.zeros(acc_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - im_csd = np.imag(csd_xy) - self._acc[0, con_idx] += im_csd - self._acc[1, con_idx] += np.abs(im_csd) - self._acc[2, con_idx] += im_csd ** 2 - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - # note: we use the trick from fieldtrip to compute the - # the estimate over all pairwise epoch combinations - sum_im_csd = self._acc[0, con_idx] - sum_abs_im_csd = self._acc[1, con_idx] - sum_sq_im_csd = self._acc[2, con_idx] - - denom = sum_abs_im_csd ** 2 - sum_sq_im_csd - - # handle zeros in denominator - z_denom = np.where(denom == 0.) - denom[z_denom] = 1. - - con = (sum_im_csd ** 2 - sum_sq_im_csd) / denom - - # where we had zeros in denominator, we set con to zero - con[z_denom] = 0. - - self.con_scores[con_idx] = con - - -class _PPCEst(_EpochMeanConEstBase): - """Pairwise Phase Consistency (PPC) Estimator.""" - - name = 'PPC' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_PPCEst, self).__init__(n_cons, n_freqs, n_times) - - # store csd / abs(csd) - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - denom = np.abs(csd_xy) - z_denom = np.where(denom == 0.) - denom[z_denom] = 1. - this_acc = csd_xy / denom - this_acc[z_denom] = 0. # handle division by zero - - self._acc[con_idx] += this_acc - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - # note: we use the trick from fieldtrip to compute the - # the estimate over all pairwise epoch combinations - con = ((self._acc[con_idx] * np.conj(self._acc[con_idx]) - n_epochs) / - (n_epochs * (n_epochs - 1.))) - - self.con_scores[con_idx] = np.real(con) - - -class _GCEstBase(_EpochMeanMultivariateConEstBase): - """Base multivariate state-space Granger causality estimator.""" - - accumulate_psd = False - - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_lags, n_jobs=1): - super(_GCEstBase, self).__init__( - n_signals, n_cons, n_freqs, n_times, n_jobs) - - self.freq_res = (self.n_freqs - 1) * 2 - if n_lags >= self.freq_res: - raise ValueError( - 'the number of lags (%i) must be less than double the ' - 'frequency resolution (%i)' % (n_lags, self.freq_res, )) - self.n_lags = n_lags - - def compute_con(self, indices, ranks, n_epochs=1): - """Compute multivariate state-space Granger causality.""" - assert self.name in ['GC', 'GC time-reversed'], ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') - - csd = self.reshape_csd() / n_epochs - - n_times = csd.shape[0] - times = np.arange(n_times) - freqs = np.arange(self.n_freqs) - - con_i = 0 - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], ranks[0], ranks[1]): - self._log_connection_number(con_i) - - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] - con_idcs = [*seed_idcs, *target_idcs] - - C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - - C_bar = self._csd_svd(C, seed_idcs, seed_rank, target_rank) - n_signals = seed_rank + target_rank - con_seeds = np.arange(seed_rank) - con_targets = np.arange(target_rank) + seed_rank - - autocov = self._compute_autocov(C_bar) - if self.name == "GC time-reversed": - autocov = autocov.transpose(0, 1, 3, 2) - - A_f, V = self._autocov_to_full_var(autocov) - A_f_3d = np.reshape( - A_f, (n_times, n_signals, n_signals * self.n_lags), - order="F") - A, K = self._full_var_to_iss(A_f_3d) - - self.con_scores[con_i] = self._iss_to_ugc( - A, A_f_3d, K, V, con_seeds, con_targets) - - con_i += 1 - - self.reshape_results() - - def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): - """Dimensionality reduction of CSD with SVD on the covariance.""" - # sum over times and epochs to get cov. from CSD - cov = csd.sum(axis=(0, 1)) - - n_seeds = len(seed_idcs) - n_targets = csd.shape[3] - n_seeds - - cov_aa = cov[:n_seeds, :n_seeds] - cov_bb = cov[n_seeds:, n_seeds:] - - if seed_rank != n_seeds: - U_aa = np.linalg.svd(np.real(cov_aa), full_matrices=False)[0] - U_bar_aa = U_aa[:, :seed_rank] - else: - U_bar_aa = np.identity(n_seeds) - - if target_rank != n_targets: - U_bb = np.linalg.svd(np.real(cov_bb), full_matrices=False)[0] - U_bar_bb = U_bb[:, :target_rank] - else: - U_bar_bb = np.identity(n_targets) - - C_aa = csd[..., :n_seeds, :n_seeds] - C_ab = csd[..., :n_seeds, n_seeds:] - C_bb = csd[..., n_seeds:, n_seeds:] - C_ba = csd[..., n_seeds:, :n_seeds] - - C_bar_aa = np.matmul( - U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul( - U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul( - U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul( - U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa)) - C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), - np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) - - return C_bar - - def _compute_autocov(self, csd): - """Compute autocovariance from the CSD.""" - n_times = csd.shape[0] - n_signals = csd.shape[2] - - circular_shifted_csd = np.concatenate( - [np.flip(np.conj(csd[:, 1:]), axis=1), csd[:, :-1]], axis=1) - ifft_shifted_csd = self._block_ifft( - circular_shifted_csd, self.freq_res) - lags_ifft_shifted_csd = np.reshape( - ifft_shifted_csd[:, :self.n_lags + 1], - (n_times, self.n_lags + 1, n_signals ** 2), order="F") - - signs = np.repeat([1], self.n_lags + 1).tolist() - signs[1::2] = [x * -1 for x in signs[1::2]] - sign_matrix = np.repeat( - np.tile(np.array(signs), (n_signals ** 2, 1))[np.newaxis], - n_times, axis=0).transpose(0, 2, 1) - - return np.real(np.reshape( - sign_matrix * lags_ifft_shifted_csd, - (n_times, self.n_lags + 1, n_signals, n_signals), order="F")) - - def _block_ifft(self, csd, n_points): - """Compute block iFFT with n points.""" - shape = csd.shape - csd_3d = np.reshape( - csd, (shape[0], shape[1], shape[2] * shape[3]), order="F") - - csd_ifft = np.fft.ifft(csd_3d, n=n_points, axis=1) - - return np.reshape(csd_ifft, shape, order="F") - - def _autocov_to_full_var(self, autocov): - """Compute full VAR model using Whittle's LWR recursion.""" - if np.any(np.linalg.det(autocov) == 0): - raise RuntimeError( - 'the autocovariance matrix is singular; check if your data is ' - 'rank deficient and specify an appropriate rank argument <= ' - 'the rank of the seeds and targets') - - A_f, V = self._whittle_lwr_recursion(autocov) - - if not np.isfinite(A_f).all(): - raise RuntimeError('at least one VAR model coefficient is ' - 'infinite or NaN; check the data you are using') - - try: - np.linalg.cholesky(V) - except np.linalg.LinAlgError as np_error: - raise RuntimeError( - 'the covariance matrix of the residuals is not ' - 'positive-definite; check the singular values of your data ' - 'and specify an appropriate rank argument <= the rank of the ' - 'seeds and targets') from np_error - - return A_f, V - - def _whittle_lwr_recursion(self, G): - """Solve Yule-Walker eqs. for full VAR params. with LWR recursion. - - See: Whittle P., 1963. Biometrika, DOI: 10.1093/biomet/50.1-2.129 - """ - # Initialise recursion - n = G.shape[2] # number of signals - q = G.shape[1] - 1 # number of lags - t = G.shape[0] # number of times - qn = n * q - - cov = G[:, 0, :, :] # covariance - G_f = np.reshape( - G[:, 1:, :, :].transpose(0, 3, 1, 2), (t, qn, n), - order="F") # forward autocov - G_b = np.reshape( - np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), - order="F").transpose(0, 2, 1) # backward autocov - - A_f = np.zeros((t, n, qn)) # forward coefficients - A_b = np.zeros((t, n, qn)) # backward coefficients - - k = 1 # model order - r = q - k - k_f = np.arange(k * n) # forward indices - k_b = np.arange(r * n, qn) # backward indices - - try: - A_f[:, :, k_f] = np.linalg.solve( - cov, G_b[:, k_b, :].transpose(0, 2, 1)).transpose(0, 2, 1) - A_b[:, :, k_b] = np.linalg.solve( - cov, G_f[:, k_f, :].transpose(0, 2, 1)).transpose(0, 2, 1) - - # Perform recursion - for k in np.arange(2, q + 1): - var_A = (G_b[:, (r - 1) * n: r * n, :] - - np.matmul(A_f[:, :, k_f], G_b[:, k_b, :])) - var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :]) - AA_f = np.linalg.solve( - var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) - - var_A = (G_f[:, (k - 1) * n: k * n, :] - - np.matmul(A_b[:, :, k_b], G_f[:, k_f, :])) - var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :]) - AA_b = np.linalg.solve( - var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) - - A_f_previous = A_f[:, :, k_f] - A_b_previous = A_b[:, :, k_b] - - r = q - k - k_f = np.arange(k * n) - k_b = np.arange(r * n, qn) - - A_f[:, :, k_f] = np.dstack( - (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f)) - A_b[:, :, k_b] = np.dstack( - (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous))) - except np.linalg.LinAlgError as np_error: - raise RuntimeError( - 'the autocovariance matrix is singular; check if your data is ' - 'rank deficient and specify an appropriate rank argument <= ' - 'the rank of the seeds and targets') from np_error - - V = cov - np.matmul(A_f, G_f) - A_f = np.reshape(A_f, (t, n, n, q), order="F") - - return A_f, V - - def _full_var_to_iss(self, A_f): - """Compute innovations-form parameters for a state-space model. - - Parameters computed from a full VAR model using Aoki's method. For a - non-moving-average full VAR model, the state-space parameter C - (observation matrix) is identical to AF of the VAR model. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - t = A_f.shape[0] - m = A_f.shape[1] # number of signals - p = A_f.shape[2] // m # number of autoregressive lags - - I_p = np.dstack(t * [np.eye(m * p)]).transpose(2, 0, 1) - A = np.hstack((A_f, I_p[:, : (m * p - m), :])) # state transition - # matrix - K = np.hstack(( - np.dstack(t * [np.eye(m)]).transpose(2, 0, 1), - np.zeros((t, (m * (p - 1)), m)))) # Kalman gain matrix - - return A, K - - def _iss_to_ugc(self, A, C, K, V, seeds, targets): - """Compute unconditional GC from innovations-form state-space params. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - times = np.arange(A.shape[0]) - freqs = np.arange(self.n_freqs) - z = np.exp(-1j * np.pi * np.linspace(0, 1, self.n_freqs)) # points - # on a unit circle in the complex plane, one for each frequency - - H = self._iss_to_tf(A, C, K, z) # spectral transfer function - V_22_1 = np.linalg.cholesky(self._partial_covar(V, seeds, targets)) - HV = np.matmul(H, np.linalg.cholesky(V)) - S = np.matmul(HV, HV.conj().transpose(0, 1, 3, 2)) # Eq. 6 - S_11 = S[np.ix_(freqs, times, targets, targets)] - HV_12 = np.matmul(H[np.ix_(freqs, times, targets, seeds)], V_22_1) - HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2)) - - # Eq. 11 - return np.real( - np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH))) - - def _iss_to_tf(self, A, C, K, z): - """Compute transfer function for innovations-form state-space params. - - In the frequency domain, the back-shift operator, z, is a vector of - points on a unit circle in the complex plane. z = e^-iw, where -pi < w - <= pi. - - A note on efficiency: solving over the 4D time-freq. tensor is slower - than looping over times and freqs when n_times and n_freqs high, and - when n_times and n_freqs low, looping over times and freqs very fast - anyway (plus tensor solving doesn't allow for parallelisation). - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - t = A.shape[0] - h = self.n_freqs - n = C.shape[1] - m = A.shape[1] - I_n = np.eye(n) - I_m = np.eye(m) - H = np.zeros((h, t, n, n), dtype=np.complex128) - - parallel, parallel_compute_H, _ = parallel_func( - _gc_compute_H, self.n_jobs, verbose=False - ) - H = np.zeros((h, t, n, n), dtype=np.complex128) - for block_i in ProgressBar( - range(self.n_steps), mesg="frequency blocks" - ): - freqs = self._get_block_indices(block_i, self.n_freqs) - H[freqs] = parallel( - parallel_compute_H(A, C, K, z[k], I_n, I_m) for k in freqs) - - return H - - def _partial_covar(self, V, seeds, targets): - """Compute partial covariance of a matrix. - - Given a covariance matrix V, the partial covariance matrix of V between - indices i and j, given k (V_ij|k), is equivalent to V_ij - V_ik * - V_kk^-1 * V_kj. In this case, i and j are seeds, and k are targets. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - times = np.arange(V.shape[0]) - W = np.linalg.solve( - np.linalg.cholesky(V[np.ix_(times, targets, targets)]), - V[np.ix_(times, targets, seeds)], - ) - W = np.matmul(W.transpose(0, 2, 1), W) - - return V[np.ix_(times, seeds, seeds)] - W - - def reshape_results(self): - """Remove time dimension from con. scores, if necessary.""" - if self.n_times == 0: - self.con_scores = self.con_scores[:, :, 0] - - -def _gc_compute_H(A, C, K, z_k, I_n, I_m): - """Compute transfer function for innovations-form state-space params. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101, Eq. 4. - """ - from scipy import linalg # XXX: is this necessary??? - H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) - for t in range(A.shape[0]): - H[t] = I_n + np.matmul( - C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t])) - - return H - - -class _GCEst(_GCEstBase): - """[seeds -> targets] state-space GC estimator.""" - - name = "GC" - - -class _GCTREst(_GCEstBase): - """time-reversed[seeds -> targets] state-space GC estimator.""" - - name = "GC time-reversed" - ############################################################################### -_multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr'] _gc_methods = ['gc', 'gc_tr'] @@ -1292,9 +254,9 @@ def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, block_size, psd, accumulate_psd, con_method_types, con_methods, n_signals, n_signals_use, n_times, gc_n_lags, - accumulate_inplace=True): + multivariate_con, accumulate_inplace=True): """Estimate connectivity for one epoch (see spectral_connectivity).""" - if any(this_method in _multivariate_methods for this_method in method): + if multivariate_con: n_con_signals = n_signals_use ** 2 else: n_con_signals = n_cons @@ -1311,8 +273,7 @@ def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, con_methods = [] for mtype in con_method_types: method_params = list(inspect.signature(mtype).parameters) - if "n_signals" in method_params: - # if it's a multivariate connectivity method + if multivariate_con: if "n_lags" in method_params: # if it's a Granger causality method con_methods.append( @@ -1501,22 +462,12 @@ def _get_and_verify_data_sizes(data, sfreq, n_signals=None, n_times=None, return n_signals, n_times, times, warn_times -# map names to estimator types -_CON_METHOD_MAP = {'coh': _CohEst, 'cohy': _CohyEst, 'imcoh': _ImCohEst, - 'plv': _PLVEst, 'ciplv': _ciPLVEst, 'ppc': _PPCEst, - 'pli': _PLIEst, 'pli2_unbiased': _PLIUnbiasedEst, - 'dpli': _DPLIEst, 'wpli': _WPLIEst, - 'wpli2_debiased': _WPLIDebiasedEst, 'mic': _MICEst, - 'mim': _MIMEst, 'gc': _GCEst, 'gc_tr': _GCTREst} - - -def _check_estimators(method): +def _check_estimators(method, con_method_map): """Check construction of connectivity estimators.""" - n_methods = len(method) con_method_types = list() for this_method in method: - if this_method in _CON_METHOD_MAP: - con_method_types.append(_CON_METHOD_MAP[this_method]) + if this_method in con_method_map: + con_method_types.append(con_method_map[this_method]) elif isinstance(this_method, str): raise ValueError('%s is not a valid connectivity method' % this_method) @@ -1532,290 +483,18 @@ def _check_estimators(method): accumulate_psd = any( this_method.accumulate_psd for this_method in con_method_types) - return con_method_types, n_methods, accumulate_psd - - -@ verbose -@ fill_doc -def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, - sfreq=None, - mode='multitaper', fmin=None, fmax=np.inf, - fskip=0, faverage=False, tmin=None, tmax=None, - mt_bandwidth=None, mt_adaptive=False, - mt_low_bias=True, cwt_freqs=None, - cwt_n_cycles=7, gc_n_lags=40, rank=None, - block_size=1000, n_jobs=1, verbose=None): - r"""Compute frequency- and time-frequency-domain connectivity measures. - - The connectivity method(s) are specified using the "method" parameter. - All methods are based on estimates of the cross- and power spectral - densities (CSD/PSD) Sxy and Sxx, Syy. - - Parameters - ---------- - data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs - The data from which to compute connectivity. Note that it is also - possible to combine multiple signals by providing a list of tuples, - e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], - corresponds to 3 epochs, and arr_* could be an array with the same - number of time points as stc_*. The array-like object can also - be a list/generator of array, shape =(n_signals, n_times), - or a list/generator of SourceEstimate or VolSourceEstimate objects. - %(names)s - method : str | list of str - Connectivity measure(s) to compute. These can be ``['coh', 'cohy', - 'imcoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', - 'wpli2_debiased', 'gc', 'gc_tr']``. Multivariate methods (``['mic', - 'mim', 'gc', 'gc_tr]``) cannot be called with the other methods. - indices : tuple of array | None - Two arrays with indices of connections for which to compute - connectivity. If a bivariate method is called, each array for the seeds - and targets should contain the channel indices for each bivariate - connection. If a multivariate method is called, each array for the - seeds and targets should consist of nested arrays containing - the channel indices for each multivariate connection. If ``None``, - connections between all channels are computed, unless a Granger - causality method is called, in which case an error is raised. - sfreq : float - The sampling frequency. Required if data is not - :class:`Epochs `. - mode : str - Spectrum estimation mode can be either: 'multitaper', 'fourier', or - 'cwt_morlet'. - fmin : float | tuple of float - The lower frequency of interest. Multiple bands are defined using - a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. - fmax : float | tuple of float - The upper frequency of interest. Multiple bands are dedined using - a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. - fskip : int - Omit every "(fskip + 1)-th" frequency bin to decimate in frequency - domain. - faverage : bool - Average connectivity scores for each frequency band. If True, - the output freqs will be a list with arrays of the frequencies - that were averaged. - tmin : float | None - Time to start connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - tmax : float | None - Time to end connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - mt_bandwidth : float | None - The bandwidth of the multitaper windowing function in Hz. - Only used in 'multitaper' mode. - mt_adaptive : bool - Use adaptive weights to combine the tapered spectra into PSD. - Only used in 'multitaper' mode. - mt_low_bias : bool - Only use tapers with more than 90 percent spectral concentration - within bandwidth. Only used in 'multitaper' mode. - cwt_freqs : array - Array of frequencies of interest. Only used in 'cwt_morlet' mode. - cwt_n_cycles : float | array of float - Number of cycles. Fixed number or one per frequency. Only used in - 'cwt_morlet' mode. - gc_n_lags : int - Number of lags to use for the vector autoregressive model when - computing Granger causality. Higher values increase computational cost, - but reduce the degree of spectral smoothing in the results. Only used - if ``method`` contains any of ``['gc', 'gc_tr']``. - rank : tuple of array | None - Two arrays with the rank to project the seed and target data to, - respectively, using singular value decomposition. If None, the rank of - the data is computed and projected to. Only used if ``method`` contains - any of ``['mic', 'mim', 'gc', 'gc_tr']``. - block_size : int - How many connections to compute at once (higher numbers are faster - but require more memory). - n_jobs : int - How many samples to process in parallel. - %(verbose)s - - Returns - ------- - con : array | list of array - Computed connectivity measure(s). Either an instance of - ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of the connectivity result will be: - - - ``(n_cons, n_freqs)`` for multitaper or fourier modes - - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode - - ``n_cons = n_signals ** 2`` for bivariate methods with - ``indices=None`` - - ``n_cons = 1`` for multivariate methods with ``indices=None`` - - ``n_cons = len(indices[0])`` for bivariate and multivariate methods - when indices is supplied. - - See Also - -------- - mne_connectivity.spectral_connectivity_time - mne_connectivity.SpectralConnectivity - mne_connectivity.SpectroTemporalConnectivity - - Notes - ----- - Please note that the interpretation of the measures in this function - depends on the data and underlying assumptions and does not necessarily - reflect a causal relationship between brain regions. - - These measures are not to be interpreted over time. Each Epoch passed into - the dataset is interpreted as an independent sample of the same - connectivity structure. Within each Epoch, it is assumed that the spectral - measure is stationary. The spectral measures implemented in this function - are computed across Epochs. **Thus, spectral measures computed with only - one Epoch will result in errorful values and spectral measures computed - with few Epochs will be unreliable.** Please see - ``spectral_connectivity_time`` for time-resolved connectivity estimation. - - The spectral densities can be estimated using a multitaper method with - digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier - transform with Hanning windows, or a continuous wavelet transform using - Morlet wavelets. The spectral estimation mode is specified using the - "mode" parameter. - - By default, the connectivity between all signals is computed (only - connections corresponding to the lower-triangular part of the connectivity - matrix). If one is only interested in the connectivity between some - signals, the "indices" parameter can be used. For example, to compute the - connectivity between the signal with index 0 and signals "2, 3, 4" (a total - of 3 connections) one can use the following:: - - indices = (np.array([0, 0, 0]), # row indices - np.array([2, 3, 4])) # col indices - - con = spectral_connectivity_epochs(data, method='coh', - indices=indices, ...) - - In this case con.get_data().shape = (3, n_freqs). The connectivity scores - are in the same order as defined indices. - - For multivariate methods, this is handled differently. If "indices" is - None, connectivity between all signals will be computed and a single - connectivity spectrum will be returned (this is not possible if a Granger - causality method is called). If "indices" is specified, seed and target - indices for each connection should be specified as nested array-likes. For - example, to compute the connectivity between signals (0, 1) -> (2, 3) and - (0, 1) -> (4, 5), indices should be specified as:: - - indices = (np.array([[0, 1], [0, 1]]), # seeds - np.array([[2, 3], [4, 5]])) # targets - - More information on working with multivariate indices and handling - connections where the number of seeds and targets are not equal can be - found in the :doc:`../auto_examples/handling_ragged_arrays` example. - - **Supported Connectivity Measures** - - The connectivity method(s) is specified using the "method" parameter. The - following methods are supported (note: ``E[]`` denotes average over - epochs). Multiple measures can be computed at once by using a list/tuple, - e.g., ``['coh', 'pli']`` to compute coherence and PLI. - - 'coh' : Coherence given by:: - - | E[Sxy] | - C = --------------------- - sqrt(E[Sxx] * E[Syy]) - - 'cohy' : Coherency given by:: - - E[Sxy] - C = --------------------- - sqrt(E[Sxx] * E[Syy]) - - 'imcoh' : Imaginary coherence :footcite:`NolteEtAl2004` given by:: - - Im(E[Sxy]) - C = ---------------------- - sqrt(E[Sxx] * E[Syy]) - - 'mic' : Maximised Imaginary part of Coherency (MIC) - :footcite:`EwaldEtAl2012` given by: - - :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} - {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} - \parallel}}` - - where: :math:`\boldsymbol{E}` is the imaginary part of the - transformed cross-spectral density between seeds and targets; and - :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are - eigenvectors for the seeds and targets, such that - :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises - connectivity between the seeds and targets. - - 'mim' : Multivariate Interaction Measure (MIM) - :footcite:`EwaldEtAl2012` given by: - - :math:`MIM=tr(\boldsymbol{EE}^T)` - - 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given - by:: - - PLV = |E[Sxy/|Sxy|]| - - 'ciplv' : corrected imaginary PLV (ciPLV) - :footcite:`BrunaEtAl2018` given by:: - - |E[Im(Sxy/|Sxy|)]| - ciPLV = ------------------------------------ - sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) - - 'ppc' : Pairwise Phase Consistency (PPC), an unbiased estimator - of squared PLV :footcite:`VinckEtAl2010`. - - 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: - - PLI = |E[sign(Im(Sxy))]| - - 'pli2_unbiased' : Unbiased estimator of squared PLI - :footcite:`VinckEtAl2011`. - - 'dpli' : Directed Phase Lag Index (DPLI) :footcite:`StamEtAl2012` - given by (where H is the Heaviside function):: + return con_method_types, accumulate_psd - DPLI = E[H(Im(Sxy))] - 'wpli' : Weighted Phase Lag Index (WPLI) :footcite:`VinckEtAl2011` - given by:: - - |E[Im(Sxy)]| - WPLI = ------------------ - E[|Im(Sxy)|] - - 'wpli2_debiased' : Debiased estimator of squared WPLI - :footcite:`VinckEtAl2011`. - - 'gc' : State-space Granger Causality (GC) :footcite:`BarnettSeth2015` - given by: - - :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert - \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss - \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, - - where: :math:`s` and :math:`t` represent the seeds and targets, - respectively; :math:`\boldsymbol{H}` is the spectral transfer - function; :math:`\boldsymbol{\Sigma}` is the residuals matrix of - the autoregressive model; and :math:`\boldsymbol{S}` is - :math:`\boldsymbol{\Sigma}` transformed by :math:`\boldsymbol{H}`. - - 'gc_tr' : State-space GC on time-reversed signals - :footcite:`BarnettSeth2015,WinklerEtAl2016` given by the same equation - as for 'gc', but where the autocovariance sequence from which the - autoregressive model is produced is transposed to mimic the reversal of - the original signal in time. - - References - ---------- - .. footbibliography:: - """ +def _check_spectral_connectivity_epochs_settings(method, fmin, fmax, n_jobs, + verbose, con_method_map): + """Check settings inputs for spectral_connectivity_epochs... functions.""" if n_jobs != 1: parallel, my_epoch_spectral_connectivity, _ = parallel_func( _epoch_spectral_connectivity, n_jobs, verbose=verbose) + else: + parallel = None + my_epoch_spectral_connectivity = None # format fmin and fmax and check inputs if fmin is None: @@ -1827,34 +506,22 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, raise ValueError('fmin and fmax must have the same length') if np.any(fmin > fmax): raise ValueError('fmax must be larger than fmin') - n_bands = len(fmin) # assign names to connectivity methods if not isinstance(method, (list, tuple)): method = [method] # make it a list so we can iterate over it - if n_bands != 1 and any( - this_method in _gc_methods for this_method in method - ): - raise ValueError('computing Granger causality on multiple frequency ' - 'bands is not yet supported') - - if any(this_method in _multivariate_methods for this_method in method): - if not all(this_method in _multivariate_methods for - this_method in method): - raise ValueError( - 'bivariate and multivariate connectivity methods cannot be ' - 'used in the same function call') - multivariate_con = True - else: - multivariate_con = False - # handle connectivity estimators - (con_method_types, n_methods, accumulate_psd) = _check_estimators(method) + con_method_types, accumulate_psd = _check_estimators(method, + con_method_map) + + return (fmin, fmax, n_bands, method, con_method_types, accumulate_psd, + parallel, my_epoch_spectral_connectivity) - events = None - event_id = None + +def _check_spectral_connectivity_epochs_data(data, sfreq, names): + """Check data inputs for spectral_connectivity_epochs... functions.""" if isinstance(data, BaseEpochs): names = data.ch_names times_in = data.times # input times for Epochs input type @@ -1876,208 +543,23 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata else: + events = None + event_id = None times_in = None metadata = None if sfreq is None: raise ValueError('Sampling frequency (sfreq) is required with ' 'array input.') - # loop over data; it could be a generator that returns - # (n_signals x n_times) arrays or SourceEstimates - epoch_idx = 0 - logger.info('Connectivity computation...') - warn_times = True - for epoch_block in _get_n_epochs(data, n_jobs): - if epoch_idx == 0: - # initialize everything times and frequencies - (n_cons, times, n_times, times_in, n_times_in, tmin_idx, - tmax_idx, n_freqs, freq_mask, freqs, freqs_bands, freq_idx_bands, - n_signals, indices_use, warn_times) = _prepare_connectivity( - epoch_block=epoch_block, times_in=times_in, - tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, - indices=indices, method=method, mode=mode, fskip=fskip, - n_bands=n_bands, cwt_freqs=cwt_freqs, faverage=faverage) - - # check rank input and compute data ranks if necessary - if multivariate_con: - rank = _check_rank_input(rank, data, indices_use) - else: - rank = None - gc_n_lags = None - - # make sure padded indices are stored in the connectivity object - if multivariate_con and indices is not None: - indices = tuple(np.array(indices_use)) # create a copy - - # get the window function, wavelets, etc for different modes - (spectral_params, mt_adaptive, n_times_spectrum, - n_tapers) = _assemble_spectral_params( - mode=mode, n_times=n_times, mt_adaptive=mt_adaptive, - mt_bandwidth=mt_bandwidth, sfreq=sfreq, - mt_low_bias=mt_low_bias, cwt_n_cycles=cwt_n_cycles, - cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) - - # unique signals for which we actually need to compute PSD etc. - if multivariate_con: - sig_idx = np.unique(np.concatenate(np.concatenate( - indices_use))) - sig_idx = sig_idx[sig_idx != -1] - remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} - remapping[-1] = -1 - remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) - con_i = 0 - for seed, target in zip(indices_use[0], indices_use[1]): - remapped_inds[0][con_i] = np.array([ - remapping[idx] for idx in seed]) - remapped_inds[1][con_i] = np.array([ - remapping[idx] for idx in target]) - con_i += 1 - remapped_sig = [remapping[idx] for idx in sig_idx] - else: - sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) - n_signals_use = len(sig_idx) - - # map indices to unique indices - if multivariate_con: - indices_use = remapped_inds # use remapped seeds & targets - idx_map = [np.sort(np.repeat(remapped_sig, len(sig_idx))), - np.tile(remapped_sig, len(sig_idx))] - else: - idx_map = [ - np.searchsorted(sig_idx, ind) for ind in indices_use] + return (names, times_in, sfreq, events, event_id, metadata) - # allocate space to accumulate PSD - if accumulate_psd: - if n_times_spectrum == 0: - psd_shape = (n_signals_use, n_freqs) - else: - psd_shape = (n_signals_use, n_freqs, n_times_spectrum) - psd = np.zeros(psd_shape) - else: - psd = None - - # create instances of the connectivity estimators - con_methods = [] - for mtype_i, mtype in enumerate(con_method_types): - method_params = dict(n_cons=n_cons, n_freqs=n_freqs, - n_times=n_times_spectrum) - if method[mtype_i] in _multivariate_methods: - method_params.update(dict(n_signals=n_signals_use)) - if method[mtype_i] in _gc_methods: - method_params.update(dict(n_lags=gc_n_lags)) - con_methods.append(mtype(**method_params)) - - sep = ', ' - metrics_str = sep.join([meth.name for meth in con_methods]) - logger.info(' the following metrics will be computed: %s' - % metrics_str) - - # check dimensions and time scale - for this_epoch in epoch_block: - _, _, _, warn_times = _get_and_verify_data_sizes( - this_epoch, sfreq, n_signals, n_times_in, times_in, - warn_times=warn_times) - - call_params = dict( - sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, - method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, - n_cons=n_cons, block_size=block_size, - psd=psd, accumulate_psd=accumulate_psd, - mt_adaptive=mt_adaptive, - con_method_types=con_method_types, - con_methods=con_methods if n_jobs == 1 else None, - n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, - gc_n_lags=gc_n_lags, - accumulate_inplace=True if n_jobs == 1 else False) - call_params.update(**spectral_params) - - if n_jobs == 1: - # no parallel processing - for this_epoch in epoch_block: - logger.info(' computing cross-spectral density for epoch %d' - % (epoch_idx + 1)) - # con methods and psd are updated inplace - _epoch_spectral_connectivity(data=this_epoch, **call_params) - epoch_idx += 1 - else: - # process epochs in parallel - logger.info( - ' computing cross-spectral density for epochs %d..%d' - % (epoch_idx + 1, epoch_idx + len(epoch_block))) - - out = parallel(my_epoch_spectral_connectivity( - data=this_epoch, **call_params) - for this_epoch in epoch_block) - # do the accumulation - for this_out in out: - for _method, parallel_method in zip(con_methods, this_out[0]): - _method.combine(parallel_method) - if accumulate_psd: - psd += this_out[1] - - epoch_idx += len(epoch_block) - - # normalize - n_epochs = epoch_idx - if accumulate_psd: - psd /= n_epochs - - # compute final connectivity scores - con = list() - patterns = list() - for method_i, conn_method in enumerate(con_methods): - - # future estimators will need to be handled here - if conn_method.accumulate_psd: - # compute scores block-wise to save memory - for i in range(0, n_cons, block_size): - con_idx = slice(i, i + block_size) - psd_xx = psd[idx_map[0][con_idx]] - psd_yy = psd[idx_map[1][con_idx]] - conn_method.compute_con(con_idx, n_epochs, psd_xx, psd_yy) - else: - # compute all scores at once - if method[method_i] in _multivariate_methods: - conn_method.compute_con(indices_use, rank, n_epochs) - else: - conn_method.compute_con(slice(0, n_cons), n_epochs) - - # get the connectivity scores - this_con = conn_method.con_scores - this_patterns = conn_method.patterns - - if this_con.shape[0] != n_cons: - raise RuntimeError( - 'first dimension of connectivity scores does not match the ' - 'number of connections; please contact the mne-connectivity ' - 'developers') - if faverage: - if this_con.shape[1] != n_freqs: - raise RuntimeError( - 'second dimension of connectivity scores does not match ' - 'the number of frequencies; please contact the ' - 'mne-connectivity developers') - con_shape = (n_cons, n_bands) + this_con.shape[2:] - this_con_bands = np.empty(con_shape, dtype=this_con.dtype) - for band_idx in range(n_bands): - this_con_bands[:, band_idx] = np.mean( - this_con[:, freq_idx_bands[band_idx]], axis=1) - this_con = this_con_bands - - if this_patterns is not None: - patterns_shape = list(this_patterns.shape) - patterns_shape[3] = n_bands - this_patterns_bands = np.empty(patterns_shape, - dtype=this_patterns.dtype) - for band_idx in range(n_bands): - this_patterns_bands[:, :, :, band_idx] = np.mean( - this_patterns[:, :, :, freq_idx_bands[band_idx]], - axis=3) - this_patterns = this_patterns_bands - - con.append(this_con) - patterns.append(this_patterns) +def _store_results( + con, patterns, method, freqs, faverage, freqs_bands, names, mode, indices, + n_epochs, times, n_tapers, metadata, events, event_id, rank, gc_n_lags, + n_signals +): + """Store results in connectivity containers.""" freqs_used = freqs if faverage: # for each band we return the frequencies that were averaged @@ -2090,23 +572,6 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, freqs_used = freqs_bands freqs_used = [[np.min(band), np.max(band)] for band in freqs_used] - if indices is None and not multivariate_con: - # return all-to-all connectivity matrices - # raveled into a 1D array - logger.info(' assembling connectivity matrix') - con_flat = con - con = list() - for this_con_flat in con_flat: - this_con = np.zeros((n_signals, n_signals) + - this_con_flat.shape[1:], - dtype=this_con_flat.dtype) - this_con[indices_use] = this_con_flat - - # ravel 2D connectivity into a 1D array - # while keeping other dimensions - this_con = this_con.reshape((n_signals ** 2,) + - this_con_flat.shape[1:]) - con.append(this_con) # number of nodes in the original data n_nodes = n_signals @@ -2131,7 +596,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, logger.info('[Connectivity computation done]') - if n_methods == 1: + if len(method) == 1: # for a single method return connectivity directly conn_list = conn_list[0] diff --git a/mne_connectivity/spectral/epochs_bivariate.py b/mne_connectivity/spectral/epochs_bivariate.py new file mode 100644 index 00000000..044de3b4 --- /dev/null +++ b/mne_connectivity/spectral/epochs_bivariate.py @@ -0,0 +1,729 @@ +# Authors: Martin Luessi +# Denis A. Engemann +# Adam Li +# Thomas S. Binns +# +# License: BSD (3-clause) + +import numpy as np +from mne.utils import logger, verbose + +from .epochs import ( + _AbstractConEstBase, _check_spectral_connectivity_epochs_settings, + _check_spectral_connectivity_epochs_data, _get_n_epochs, + _prepare_connectivity, _assemble_spectral_params, + _compute_spectral_methods_epochs, _store_results) +from ..utils import fill_doc, check_indices + + +def _check_indices(indices, n_signals): + if indices is None: + logger.info('only using indices for lower-triangular matrix') + # only compute r for lower-triangular region + indices_use = np.tril_indices(n_signals, -1) + else: + indices_use = check_indices(indices) + + # number of connectivities to compute + n_cons = len(indices_use[0]) + logger.info(' computing connectivity for %d connections' % n_cons) + + return n_cons, indices_use + + +######################################################################## +# Bivariate connectivity estimators + + +class _EpochMeanConEstBase(_AbstractConEstBase): + """Base class for methods that estimate connectivity as mean epoch-wise.""" + + patterns = None + + def __init__(self, n_cons, n_freqs, n_times): + self.n_cons = n_cons + self.n_freqs = n_freqs + self.n_times = n_times + + if n_times == 0: + self.csd_shape = (n_cons, n_freqs) + else: + self.csd_shape = (n_cons, n_freqs, n_times) + + self.con_scores = None + + def start_epoch(self): # noqa: D401 + """Called at the start of each epoch.""" + pass # for this type of con. method we don't do anything + + def combine(self, other): + """Include con. accumated for some epochs in this estimate.""" + self._acc += other._acc + + +class _CohEstBase(_EpochMeanConEstBase): + """Base Estimator for Coherence, Coherency, Imag. Coherence.""" + + accumulate_psd = True + + def __init__(self, n_cons, n_freqs, n_times): + super(_CohEstBase, self).__init__(n_cons, n_freqs, n_times) + + # allocate space for accumulation of CSD + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate CSD for some connections.""" + self._acc[con_idx] += csd_xy + + +class _CohEst(_CohEstBase): + """Coherence Estimator.""" + + name = 'Coherence' + + def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + csd_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy) + + +class _CohyEst(_CohEstBase): + """Coherency Estimator.""" + + name = 'Coherency' + + def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape, + dtype=np.complex128) + csd_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = csd_mean / np.sqrt(psd_xx * psd_yy) + + +class _ImCohEst(_CohEstBase): + """Imaginary Coherence Estimator.""" + + name = 'Imaginary Coherence' + + def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + csd_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = np.imag(csd_mean) / np.sqrt(psd_xx * psd_yy) + + +class _PLVEst(_EpochMeanConEstBase): + """PLV Estimator.""" + + name = 'PLV' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_PLVEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += csd_xy / np.abs(csd_xy) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + plv = np.abs(self._acc / n_epochs) + self.con_scores[con_idx] = plv + + +class _ciPLVEst(_EpochMeanConEstBase): + """corrected imaginary PLV Estimator.""" + + name = 'ciPLV' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_ciPLVEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += csd_xy / np.abs(csd_xy) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + imag_plv = np.abs(np.imag(self._acc)) / n_epochs + real_plv = np.real(self._acc) / n_epochs + real_plv = np.clip(real_plv, -1, 1) # bounded from -1 to 1 + mask = (np.abs(real_plv) == 1) # avoid division by 0 + real_plv[mask] = 0 + corrected_imag_plv = imag_plv / np.sqrt(1 - real_plv ** 2) + self.con_scores[con_idx] = corrected_imag_plv + + +class _PLIEst(_EpochMeanConEstBase): + """PLI Estimator.""" + + name = 'PLI' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_PLIEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += np.sign(np.imag(csd_xy)) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + pli_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = np.abs(pli_mean) + + +class _PLIUnbiasedEst(_PLIEst): + """Unbiased PLI Square Estimator.""" + + name = 'Unbiased PLI Square' + accumulate_psd = False + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + pli_mean = self._acc[con_idx] / n_epochs + + # See Vinck paper Eq. (30) + con = (n_epochs * pli_mean ** 2 - 1) / (n_epochs - 1) + + self.con_scores[con_idx] = con + + +class _DPLIEst(_EpochMeanConEstBase): + """DPLI Estimator.""" + + name = 'DPLI' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_DPLIEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += np.heaviside(np.imag(csd_xy), 0.5) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + con = self._acc[con_idx] / n_epochs + + self.con_scores[con_idx] = con + + +class _WPLIEst(_EpochMeanConEstBase): + """WPLI Estimator.""" + + name = 'WPLI' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_WPLIEst, self).__init__(n_cons, n_freqs, n_times) + + # store both imag(csd) and abs(imag(csd)) + acc_shape = (2,) + self.csd_shape + self._acc = np.zeros(acc_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + im_csd = np.imag(csd_xy) + self._acc[0, con_idx] += im_csd + self._acc[1, con_idx] += np.abs(im_csd) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + num = np.abs(self._acc[0, con_idx]) + denom = self._acc[1, con_idx] + + # handle zeros in denominator + z_denom = np.where(denom == 0.) + denom[z_denom] = 1. + + con = num / denom + + # where we had zeros in denominator, we set con to zero + con[z_denom] = 0. + + self.con_scores[con_idx] = con + + +class _WPLIDebiasedEst(_EpochMeanConEstBase): + """Debiased WPLI Square Estimator.""" + + name = 'Debiased WPLI Square' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_WPLIDebiasedEst, self).__init__(n_cons, n_freqs, n_times) + # store imag(csd), abs(imag(csd)), imag(csd)^2 + acc_shape = (3,) + self.csd_shape + self._acc = np.zeros(acc_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + im_csd = np.imag(csd_xy) + self._acc[0, con_idx] += im_csd + self._acc[1, con_idx] += np.abs(im_csd) + self._acc[2, con_idx] += im_csd ** 2 + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + # note: we use the trick from fieldtrip to compute the + # the estimate over all pairwise epoch combinations + sum_im_csd = self._acc[0, con_idx] + sum_abs_im_csd = self._acc[1, con_idx] + sum_sq_im_csd = self._acc[2, con_idx] + + denom = sum_abs_im_csd ** 2 - sum_sq_im_csd + + # handle zeros in denominator + z_denom = np.where(denom == 0.) + denom[z_denom] = 1. + + con = (sum_im_csd ** 2 - sum_sq_im_csd) / denom + + # where we had zeros in denominator, we set con to zero + con[z_denom] = 0. + + self.con_scores[con_idx] = con + + +class _PPCEst(_EpochMeanConEstBase): + """Pairwise Phase Consistency (PPC) Estimator.""" + + name = 'PPC' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_PPCEst, self).__init__(n_cons, n_freqs, n_times) + + # store csd / abs(csd) + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + denom = np.abs(csd_xy) + z_denom = np.where(denom == 0.) + denom[z_denom] = 1. + this_acc = csd_xy / denom + this_acc[z_denom] = 0. # handle division by zero + + self._acc[con_idx] += this_acc + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + # note: we use the trick from fieldtrip to compute the + # the estimate over all pairwise epoch combinations + con = ((self._acc[con_idx] * np.conj(self._acc[con_idx]) - n_epochs) / + (n_epochs * (n_epochs - 1.))) + + self.con_scores[con_idx] = np.real(con) + + +############################################################################### + + +# map names to estimator types +_CON_METHOD_MAP = {'coh': _CohEst, 'cohy': _CohyEst, 'imcoh': _ImCohEst, + 'plv': _PLVEst, 'ciplv': _ciPLVEst, 'ppc': _PPCEst, + 'pli': _PLIEst, 'pli2_unbiased': _PLIUnbiasedEst, + 'dpli': _DPLIEst, 'wpli': _WPLIEst, + 'wpli2_debiased': _WPLIDebiasedEst} + + +@ verbose +@ fill_doc +def spectral_connectivity_epochs( + data, names=None, method='coh', indices=None, sfreq=None, + mode='multitaper', fmin=None, fmax=np.inf, fskip=0, faverage=False, + tmin=None, tmax=None, mt_bandwidth=None, mt_adaptive=False, + mt_low_bias=True, cwt_freqs=None, cwt_n_cycles=7, block_size=1000, + n_jobs=1, verbose=None +): + """Compute bivariate (time-)frequency-domain connectivity measures. + + The connectivity method(s) are specified using the "method" parameter. + All methods are based on estimates of the cross- and power spectral + densities (CSD/PSD) Sxy and Sxx, Syy. + + Parameters + ---------- + data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs + The data from which to compute connectivity. Note that it is also + possible to combine multiple signals by providing a list of tuples, + e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], + corresponds to 3 epochs, and arr_* could be an array with the same + number of time points as stc_*. The array-like object can also + be a list/generator of array, shape =(n_signals, n_times), + or a list/generator of SourceEstimate or VolSourceEstimate objects. + %(names)s + method : str | list of str + Connectivity measure(s) to compute. These can be ``['coh', 'cohy', + 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', + 'wpli2_debiased']``. + indices : tuple of array | None + Two arrays with indices of connections for which to compute + connectivity. Each array for the seeds and targets should contain the + channel indices for each bivariate connection. If ``None``, connections + between all channels are computed. + sfreq : float + The sampling frequency. Required if data is not + :class:`Epochs `. + mode : str + Spectrum estimation mode can be either: 'multitaper', 'fourier', or + 'cwt_morlet'. + fmin : float | tuple of float + The lower frequency of interest. Multiple bands are defined using + a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. + fmax : float | tuple of float + The upper frequency of interest. Multiple bands are dedined using + a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. + fskip : int + Omit every "(fskip + 1)-th" frequency bin to decimate in frequency + domain. + faverage : bool + Average connectivity scores for each frequency band. If True, + the output freqs will be a list with arrays of the frequencies + that were averaged. + tmin : float | None + Time to start connectivity estimation. Note: when "data" is an array, + the first sample is assumed to be at time 0. For other types + (Epochs, etc.), the time information contained in the object is used + to compute the time indices. + tmax : float | None + Time to end connectivity estimation. Note: when "data" is an array, + the first sample is assumed to be at time 0. For other types + (Epochs, etc.), the time information contained in the object is used + to compute the time indices. + mt_bandwidth : float | None + The bandwidth of the multitaper windowing function in Hz. + Only used in 'multitaper' mode. + mt_adaptive : bool + Use adaptive weights to combine the tapered spectra into PSD. + Only used in 'multitaper' mode. + mt_low_bias : bool + Only use tapers with more than 90 percent spectral concentration + within bandwidth. Only used in 'multitaper' mode. + cwt_freqs : array + Array of frequencies of interest. Only used in 'cwt_morlet' mode. + cwt_n_cycles : float | array of float + Number of cycles. Fixed number or one per frequency. Only used in + 'cwt_morlet' mode. + block_size : int + How many connections to compute at once (higher numbers are faster + but require more memory). + n_jobs : int + How many samples to process in parallel. + %(verbose)s + + Returns + ------- + con : array | list of array + Computed connectivity measure(s). Either an instance of + ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. + The shape of the connectivity result will be: + + - ``(n_cons, n_freqs)`` for multitaper or fourier modes + - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode + - ``n_cons = n_signals ** 2`` with ``indices=None`` + - ``n_cons = len(indices[0])`` when indices is supplied. + + See Also + -------- + mne_connectivity.spectral_connectivity_epochs_multivariate + mne_connectivity.spectral_connectivity_time + mne_connectivity.SpectralConnectivity + mne_connectivity.SpectroTemporalConnectivity + + Notes + ----- + Please note that the interpretation of the measures in this function + depends on the data and underlying assumptions and does not necessarily + reflect a causal relationship between brain regions. + + These measures are not to be interpreted over time. Each Epoch passed into + the dataset is interpreted as an independent sample of the same + connectivity structure. Within each Epoch, it is assumed that the spectral + measure is stationary. The spectral measures implemented in this function + are computed across Epochs. **Thus, spectral measures computed with only + one Epoch will result in errorful values and spectral measures computed + with few Epochs will be unreliable.** Please see + ``spectral_connectivity_time`` for time-resolved connectivity estimation. + + The spectral densities can be estimated using a multitaper method with + digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier + transform with Hanning windows, or a continuous wavelet transform using + Morlet wavelets. The spectral estimation mode is specified using the + "mode" parameter. + + By default, the connectivity between all signals is computed (only + connections corresponding to the lower-triangular part of the connectivity + matrix). If one is only interested in the connectivity between some + signals, the "indices" parameter can be used. For example, to compute the + connectivity between the signal with index 0 and signals "2, 3, 4" (a total + of 3 connections) one can use the following:: + + indices = (np.array([0, 0, 0]), # row indices + np.array([2, 3, 4])) # col indices + + con = spectral_connectivity_epochs(data, method='coh', + indices=indices, ...) + + In this case con.get_data().shape = (3, n_freqs). The connectivity scores + are in the same order as defined indices. + + **Supported Connectivity Measures** + + The connectivity method(s) is specified using the "method" parameter. The + following methods are supported (note: ``E[]`` denotes average over + epochs). Multiple measures can be computed at once by using a list/tuple, + e.g., ``['coh', 'pli']`` to compute coherence and PLI. + + 'coh' : Coherence given by:: + + | E[Sxy] | + C = --------------------- + sqrt(E[Sxx] * E[Syy]) + + 'cohy' : Coherency given by:: + + E[Sxy] + C = --------------------- + sqrt(E[Sxx] * E[Syy]) + + 'imcoh' : Imaginary coherence :footcite:`NolteEtAl2004` given by:: + + Im(E[Sxy]) + C = ---------------------- + sqrt(E[Sxx] * E[Syy]) + + 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given + by:: + + PLV = |E[Sxy/|Sxy|]| + + 'ciplv' : corrected imaginary PLV (ciPLV) + :footcite:`BrunaEtAl2018` given by:: + + |E[Im(Sxy/|Sxy|)]| + ciPLV = ------------------------------------ + sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) + + 'ppc' : Pairwise Phase Consistency (PPC), an unbiased estimator + of squared PLV :footcite:`VinckEtAl2010`. + + 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: + + PLI = |E[sign(Im(Sxy))]| + + 'pli2_unbiased' : Unbiased estimator of squared PLI + :footcite:`VinckEtAl2011`. + + 'dpli' : Directed Phase Lag Index (DPLI) :footcite:`StamEtAl2012` + given by (where H is the Heaviside function):: + + DPLI = E[H(Im(Sxy))] + + 'wpli' : Weighted Phase Lag Index (WPLI) :footcite:`VinckEtAl2011` + given by:: + + |E[Im(Sxy)]| + WPLI = ------------------ + E[|Im(Sxy)|] + + 'wpli2_debiased' : Debiased estimator of squared WPLI + :footcite:`VinckEtAl2011`. + + References + ---------- + .. footbibliography:: + """ + ( + fmin, fmax, n_bands, method, con_method_types, accumulate_psd, + parallel, my_epoch_spectral_connectivity + ) = _check_spectral_connectivity_epochs_settings( + method, fmin, fmax, n_jobs, verbose, _CON_METHOD_MAP) + + (names, times_in, sfreq, events, event_id, + metadata) = _check_spectral_connectivity_epochs_data(data, sfreq, names) + + # loop over data; it could be a generator that returns + # (n_signals x n_times) arrays or SourceEstimates + epoch_idx = 0 + logger.info('Connectivity computation...') + warn_times = True + for epoch_block in _get_n_epochs(data, n_jobs): + if epoch_idx == 0: + # initialize everything times and frequencies + (times, n_times, times_in, n_times_in, tmin_idx, tmax_idx, n_freqs, + freq_mask, freqs, freqs_bands, freq_idx_bands, n_signals, + warn_times) = _prepare_connectivity( + epoch_block=epoch_block, times_in=times_in, tmin=tmin, + tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, mode=mode, + fskip=fskip, n_bands=n_bands, cwt_freqs=cwt_freqs, + faverage=faverage) + + # check indices input + n_cons, indices_use = _check_indices(indices, n_signals) + + # get the window function, wavelets, etc for different modes + (spectral_params, mt_adaptive, n_times_spectrum, + n_tapers) = _assemble_spectral_params( + mode=mode, n_times=n_times, mt_adaptive=mt_adaptive, + mt_bandwidth=mt_bandwidth, sfreq=sfreq, + mt_low_bias=mt_low_bias, cwt_n_cycles=cwt_n_cycles, + cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) + + # unique signals for which we actually need to compute CSD/PSD + sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) + n_signals_use = len(sig_idx) + + # map indices to unique indices + idx_map = [np.searchsorted(sig_idx, ind) for ind in indices_use] + + # allocate space to accumulate PSD + if accumulate_psd: + if n_times_spectrum == 0: + psd_shape = (n_signals_use, n_freqs) + else: + psd_shape = (n_signals_use, n_freqs, n_times_spectrum) + psd = np.zeros(psd_shape) + else: + psd = None + + # create instances of the connectivity estimators + con_methods = [] + for mtype in con_method_types: + con_methods.append(mtype(n_cons=n_cons, n_freqs=n_freqs, + n_times=n_times_spectrum)) + + sep = ', ' + metrics_str = sep.join([meth.name for meth in con_methods]) + logger.info(' the following metrics will be computed: %s' + % metrics_str) + + call_params = dict( + sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, + method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, + n_cons=n_cons, block_size=block_size, + psd=psd, accumulate_psd=accumulate_psd, + mt_adaptive=mt_adaptive, + con_method_types=con_method_types, + con_methods=con_methods if n_jobs == 1 else None, + n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, + gc_n_lags=None, multivariate_con=False, + accumulate_inplace=True if n_jobs == 1 else False) + call_params.update(**spectral_params) + + epoch_idx = _compute_spectral_methods_epochs( + con_methods, epoch_block, epoch_idx, call_params, parallel, + my_epoch_spectral_connectivity, n_jobs, n_times_in, times_in, + warn_times) + + # normalize + n_epochs = epoch_idx + if accumulate_psd: + psd /= n_epochs + + # compute final connectivity scores + con = list() + for conn_method in con_methods: + + # future estimators will need to be handled here + if conn_method.accumulate_psd: + # compute scores block-wise to save memory + for i in range(0, n_cons, block_size): + con_idx = slice(i, i + block_size) + psd_xx = psd[idx_map[0][con_idx]] + psd_yy = psd[idx_map[1][con_idx]] + conn_method.compute_con(con_idx, n_epochs, psd_xx, psd_yy) + else: + # compute all scores at once + conn_method.compute_con(slice(0, n_cons), n_epochs) + + # get the connectivity scores + this_con = conn_method.con_scores + + if this_con.shape[0] != n_cons: + raise RuntimeError( + 'first dimension of connectivity scores does not match the ' + 'number of connections; please contact the mne-connectivity ' + 'developers') + if faverage: + if this_con.shape[1] != n_freqs: + raise RuntimeError( + 'second dimension of connectivity scores does not match ' + 'the number of frequencies; please contact the ' + 'mne-connectivity developers') + con_shape = (n_cons, n_bands) + this_con.shape[2:] + this_con_bands = np.empty(con_shape, dtype=this_con.dtype) + for band_idx in range(n_bands): + this_con_bands[:, band_idx] = np.mean( + this_con[:, freq_idx_bands[band_idx]], axis=1) + this_con = this_con_bands + + con.append(this_con) + # No patterns for bivariate connectivity + patterns = [None for _ in range(len(con))] + + # return all-to-all connectivity matrices raveled into a 1D array + if indices is None: + logger.info(' assembling connectivity matrix') + con_flat = con + con = list() + for this_con_flat in con_flat: + this_con = np.zeros((n_signals, n_signals) + + this_con_flat.shape[1:], + dtype=this_con_flat.dtype) + this_con[indices_use] = this_con_flat + + # ravel 2D connectivity into a 1D array + # while keeping other dimensions + this_con = this_con.reshape((n_signals ** 2,) + + this_con_flat.shape[1:]) + con.append(this_con) + + conn_list = _store_results( + con=con, patterns=patterns, method=method, freqs=freqs, + faverage=faverage, freqs_bands=freqs_bands, names=names, mode=mode, + indices=indices, n_epochs=n_epochs, times=times, n_tapers=n_tapers, + metadata=metadata, events=events, event_id=event_id, rank=None, + gc_n_lags=None, n_signals=n_signals) + + return conn_list diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py new file mode 100644 index 00000000..28077adb --- /dev/null +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -0,0 +1,1129 @@ +# Authors: Martin Luessi +# Denis A. Engemann +# Adam Li +# Thomas S. Binns +# Tien D. Nguyen +# Richard M. Köhler +# +# License: BSD (3-clause) + +import numpy as np +import scipy as sp +from mne.epochs import BaseEpochs +from mne.parallel import parallel_func +from mne.utils import ProgressBar, logger, verbose + +from .epochs import ( + _AbstractConEstBase, _check_spectral_connectivity_epochs_settings, + _check_spectral_connectivity_epochs_data, _get_n_epochs, + _prepare_connectivity, _assemble_spectral_params, + _compute_spectral_methods_epochs, _store_results) +from ..utils import fill_doc, check_multivariate_indices + + +def _check_indices(indices, method, n_signals): + if indices is None: + if any(this_method in _gc_methods for this_method in method): + raise ValueError( + 'indices must be specified when computing Granger causality, ' + 'as all-to-all connectivity is not supported') + else: + logger.info('using all indices for multivariate connectivity') + indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) + else: + indices_use = check_multivariate_indices(indices) # pad with -1 + if any(this_method in _gc_methods for this_method in method): + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries + raise ValueError( + 'seed and target indices must not intersect when ' + 'computing Granger causality') + + # number of connectivities to compute + n_cons = len(indices_use[0]) + logger.info(' computing connectivity for %d connections' % n_cons) + + return n_cons, indices_use + + +def _check_rank_input(rank, data, indices): + """Check the rank argument is appropriate and compute rank if missing.""" + sv_tol = 1e-10 # tolerance for non-zero singular val (rel. to largest) + if rank is None: + rank = np.zeros((2, len(indices[0])), dtype=int) + + if isinstance(data, BaseEpochs): + data_arr = data.get_data() + else: + data_arr = data + + # XXX: Unpadding of arrays after already padding them is perhaps not so + # efficient. However, we need to remove the padded values to + # ensure only the correct channels are indexed, and having two + # versions of indices is a bit messy currently. A candidate for + # refactoring to simplify code. + + for group_i in range(2): # seeds and targets + for con_i, con_idcs in enumerate(indices[group_i]): + con_idcs = con_idcs[con_idcs != -1] # -1 is padded value + s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) + rank[group_i][con_i] = np.min( + [np.count_nonzero(epoch >= epoch[0] * sv_tol) + for epoch in s]) + + logger.info('Estimated data ranks:') + con_i = 1 + for seed_rank, target_rank in zip(rank[0], rank[1]): + logger.info(' connection %i - seeds (%i); targets (%i)' + % (con_i, seed_rank, target_rank, )) + con_i += 1 + rank = tuple((np.array(rank[0]), np.array(rank[1]))) + + else: + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], rank[0], rank[1]): + if not (0 < seed_rank <= len(seed_idcs) and + 0 < target_rank <= len(target_idcs)): + raise ValueError( + 'ranks for seeds and targets must be > 0 and <= the ' + 'number of channels in the seeds and targets, ' + 'respectively, for each connection') + + return rank + + +######################################################################## +# Multivariate connectivity estimators + +class _EpochMeanMultivariateConEstBase(_AbstractConEstBase): + """Base class for mean epoch-wise multivar. con. estimation methods.""" + + n_steps = None + patterns = None + + def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): + self.n_signals = n_signals + self.n_cons = n_cons + self.n_freqs = n_freqs + self.n_times = n_times + self.n_jobs = n_jobs + + # include time dimension, even when unused for indexing flexibility + if n_times == 0: + self.csd_shape = (n_signals**2, n_freqs) + self.con_scores = np.zeros((n_cons, n_freqs, 1)) + else: + self.csd_shape = (n_signals**2, n_freqs, n_times) + self.con_scores = np.zeros((n_cons, n_freqs, n_times)) + + # allocate space for accumulation of CSD + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + self._compute_n_progress_bar_steps() + + def start_epoch(self): # noqa: D401 + """Called at the start of each epoch.""" + pass # for this type of con. method we don't do anything + + def combine(self, other): + """Include con. accumulated for some epochs in this estimate.""" + self._acc += other._acc + + def accumulate(self, con_idx, csd_xy): + """Accumulate CSD for some connections.""" + self._acc[con_idx] += csd_xy + + def _compute_n_progress_bar_steps(self): + """Calculate the number of steps to include in the progress bar.""" + self.n_steps = int(np.ceil(self.n_freqs / self.n_jobs)) + + def _log_connection_number(self, con_i): + """Log the number of the connection being computed.""" + logger.info('Computing %s for connection %i of %i' + % (self.name, con_i + 1, self.n_cons, )) + + def _get_block_indices(self, block_i, limit): + """Get indices for a computation block capped by a limit.""" + indices = np.arange(block_i * self.n_jobs, (block_i + 1) * self.n_jobs) + + return indices[np.nonzero(indices < limit)] + + def reshape_csd(self): + """Reshape CSD into a matrix of times x freqs x signals x signals.""" + if self.n_times == 0: + return (np.reshape(self._acc, ( + self.n_signals, self.n_signals, self.n_freqs, 1) + ).transpose(3, 2, 0, 1)) + + return (np.reshape(self._acc, ( + self.n_signals, self.n_signals, self.n_freqs, self.n_times) + ).transpose(3, 2, 0, 1)) + + +class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): + """Base estimator for multivariate imag. part of coherency methods. + + See Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 + for equation references. + """ + + name = None + accumulate_psd = False + + def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): + super(_MultivariateCohEstBase, self).__init__( + n_signals, n_cons, n_freqs, n_times, n_jobs) + + def compute_con(self, indices, ranks, n_epochs=1): + """Compute multivariate imag. part of coherency between signals.""" + assert self.name in ['MIC', 'MIM'], ( + 'the class name is not recognised, please contact the ' + 'mne-connectivity developers') + + csd = self.reshape_csd() / n_epochs + n_times = csd.shape[0] + times = np.arange(n_times) + freqs = np.arange(self.n_freqs) + + if self.name == 'MIC': + self.patterns = np.full( + (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), + np.nan) + + con_i = 0 + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], ranks[0], ranks[1]): + self._log_connection_number(con_i) + + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] + con_idcs = [*seed_idcs, *target_idcs] + + C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] + + # Eqs. 32 & 33 + C_bar, U_bar_aa, U_bar_bb = self._csd_svd( + C, seed_idcs, seed_rank, target_rank) + + # Eqs. 3 & 4 + E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) + + if self.name == 'MIC': + self._compute_mic(E, C, seed_idcs, target_idcs, n_times, + U_bar_aa, U_bar_bb, con_i) + else: + self._compute_mim(E, seed_idcs, target_idcs, con_i) + + con_i += 1 + + self.reshape_results() + + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): + """Dimensionality reduction of CSD with SVD.""" + n_times = csd.shape[0] + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds + + C_aa = csd[..., :n_seeds, :n_seeds] + C_ab = csd[..., :n_seeds, n_seeds:] + C_bb = csd[..., n_seeds:, n_seeds:] + C_ba = csd[..., n_seeds:, :n_seeds] + + # Eq. 32 + if seed_rank != n_seeds: + U_aa = np.linalg.svd(np.real(C_aa), full_matrices=False)[0] + U_bar_aa = U_aa[..., :seed_rank] + else: + U_bar_aa = np.broadcast_to( + np.identity(n_seeds), + (n_times, self.n_freqs) + (n_seeds, n_seeds)) + + if target_rank != n_targets: + U_bb = np.linalg.svd(np.real(C_bb), full_matrices=False)[0] + U_bar_bb = U_bb[..., :target_rank] + else: + U_bar_bb = np.broadcast_to( + np.identity(n_targets), + (n_times, self.n_freqs) + (n_targets, n_targets)) + + # Eq. 33 + C_bar_aa = np.matmul( + U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) + C_bar_ab = np.matmul( + U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb)) + C_bar_bb = np.matmul( + U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb)) + C_bar_ba = np.matmul( + U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa)) + C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), + np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) + + return C_bar, U_bar_aa, U_bar_bb + + def _compute_e(self, csd, n_seeds): + """Compute E from the CSD.""" + C_r = np.real(csd) + + parallel, parallel_compute_t, _ = parallel_func( + _mic_mim_compute_t, self.n_jobs, verbose=False) + + # imag. part of T filled when data is rank-deficient + T = np.zeros(csd.shape, dtype=np.complex128) + for block_i in ProgressBar( + range(self.n_steps), mesg="frequency blocks"): + freqs = self._get_block_indices(block_i, self.n_freqs) + T[:, freqs] = np.array(parallel(parallel_compute_t( + C_r[:, f], T[:, f], n_seeds) for f in freqs) + ).transpose(1, 0, 2, 3) + + if not np.isreal(T).all() or not np.isfinite(T).all(): + raise RuntimeError( + 'the transformation matrix of the data must be real-valued ' + 'and contain no NaN or infinity values; check that you are ' + 'using full rank data or specify an appropriate rank for the ' + 'seeds and targets that is less than or equal to their ranks') + T = np.real(T) # make T real if check passes + + # Eq. 4 + D = np.matmul(T, np.matmul(csd, T)) + + # E as imag. part of D between seeds and targets + return np.imag(D[..., :n_seeds, n_seeds:]) + + def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, + U_bar_bb, con_i): + """Compute MIC and the associated spatial patterns.""" + n_seeds = len(seed_idcs) + n_targets = len(target_idcs) + times = np.arange(n_times) + freqs = np.arange(self.n_freqs) + + # Eigendecomp. to find spatial filters for seeds and targets + w_seeds, V_seeds = np.linalg.eigh( + np.matmul(E, E.transpose(0, 1, 3, 2))) + w_targets, V_targets = np.linalg.eigh( + np.matmul(E.transpose(0, 1, 3, 2), E)) + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + ): + # strange edge-case where the eigenvectors returned should be a set + # of identity matrices with one rotated by 90 degrees, but are + # instead identical (i.e. are not rotated versions of one another). + # This leads to the case where the spatial filters are incorrectly + # applied, resulting in connectivity estimates of ~0 when they + # should be perfectly correlated ~1. Accordingly, we manually + # create a set of rotated identity matrices to use as the filters. + create_filter = False + stop = False + while not create_filter and not stop: + for time_i in range(n_times): + for freq_i in range(self.n_freqs): + if np.all(V_seeds[time_i, freq_i] == + V_targets[time_i, freq_i]): + create_filter = True + break + stop = True + if create_filter: + n_chans = E.shape[2] + eye_4d = np.zeros_like(V_seeds) + eye_4d[:, :, np.arange(n_chans), np.arange(n_chans)] = 1 + V_seeds = eye_4d + V_targets = np.rot90(eye_4d, axes=(2, 3)) + + # Spatial filters with largest eigval. for seeds and targets + alpha = V_seeds[times[:, None], freqs, :, w_seeds.argmax(axis=2)] + beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] + + # Eq. 46 (seed spatial patterns) + self.patterns[0, con_i, :n_seeds] = (np.matmul( + np.real(C[..., :n_seeds, :n_seeds]), + np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T + + # Eq. 47 (target spatial patterns) + self.patterns[1, con_i, :n_targets] = (np.matmul( + np.real(C[..., n_seeds:, n_seeds:]), + np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T + + # Eq. 7 + self.con_scores[con_i] = (np.einsum( + 'ijk,ijk->ij', alpha, np.matmul(E, np.expand_dims( + beta, axis=3))[..., 0] + ) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2)).T + + def _compute_mim(self, E, seed_idcs, target_idcs, con_i): + """Compute MIM (a.k.a. GIM if seeds == targets).""" + # Eq. 14 + self.con_scores[con_i] = np.matmul( + E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T + + # Eq. 15 + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + ): + self.con_scores[con_i] *= 0.5 + + def reshape_results(self): + """Remove time dimension from results, if necessary.""" + if self.n_times == 0: + self.con_scores = self.con_scores[..., 0] + if self.patterns is not None: + self.patterns = self.patterns[..., 0] + + +def _mic_mim_compute_t(C, T, n_seeds): + """Compute T for a single frequency (used for MIC and MIM).""" + for time_i in range(C.shape[0]): + T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( + C[time_i, :n_seeds, :n_seeds], -0.5 + ) + T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power( + C[time_i, n_seeds:, n_seeds:], -0.5 + ) + + return T + + +class _MICEst(_MultivariateCohEstBase): + """Multivariate imaginary part of coherency (MIC) estimator.""" + + name = "MIC" + + +class _MIMEst(_MultivariateCohEstBase): + """Multivariate interaction measure (MIM) estimator.""" + + name = "MIM" + + +class _GCEstBase(_EpochMeanMultivariateConEstBase): + """Base multivariate state-space Granger causality estimator.""" + + accumulate_psd = False + + def __init__(self, n_signals, n_cons, n_freqs, n_times, n_lags, n_jobs=1): + super(_GCEstBase, self).__init__( + n_signals, n_cons, n_freqs, n_times, n_jobs) + + self.freq_res = (self.n_freqs - 1) * 2 + if n_lags >= self.freq_res: + raise ValueError( + 'the number of lags (%i) must be less than double the ' + 'frequency resolution (%i)' % (n_lags, self.freq_res, )) + self.n_lags = n_lags + + def compute_con(self, indices, ranks, n_epochs=1): + """Compute multivariate state-space Granger causality.""" + assert self.name in ['GC', 'GC time-reversed'], ( + 'the class name is not recognised, please contact the ' + 'mne-connectivity developers') + + csd = self.reshape_csd() / n_epochs + + n_times = csd.shape[0] + times = np.arange(n_times) + freqs = np.arange(self.n_freqs) + + con_i = 0 + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], ranks[0], ranks[1]): + self._log_connection_number(con_i) + + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] + con_idcs = [*seed_idcs, *target_idcs] + + C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] + + C_bar = self._csd_svd(C, seed_idcs, seed_rank, target_rank) + n_signals = seed_rank + target_rank + con_seeds = np.arange(seed_rank) + con_targets = np.arange(target_rank) + seed_rank + + autocov = self._compute_autocov(C_bar) + if self.name == "GC time-reversed": + autocov = autocov.transpose(0, 1, 3, 2) + + A_f, V = self._autocov_to_full_var(autocov) + A_f_3d = np.reshape( + A_f, (n_times, n_signals, n_signals * self.n_lags), order="F") + A, K = self._full_var_to_iss(A_f_3d) + + self.con_scores[con_i] = self._iss_to_ugc( + A, A_f_3d, K, V, con_seeds, con_targets) + + con_i += 1 + + self.reshape_results() + + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): + """Dimensionality reduction of CSD with SVD on the covariance.""" + # sum over times and epochs to get cov. from CSD + cov = csd.sum(axis=(0, 1)) + + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds + + cov_aa = cov[:n_seeds, :n_seeds] + cov_bb = cov[n_seeds:, n_seeds:] + + if seed_rank != n_seeds: + U_aa = np.linalg.svd(np.real(cov_aa), full_matrices=False)[0] + U_bar_aa = U_aa[:, :seed_rank] + else: + U_bar_aa = np.identity(n_seeds) + + if target_rank != n_targets: + U_bb = np.linalg.svd(np.real(cov_bb), full_matrices=False)[0] + U_bar_bb = U_bb[:, :target_rank] + else: + U_bar_bb = np.identity(n_targets) + + C_aa = csd[..., :n_seeds, :n_seeds] + C_ab = csd[..., :n_seeds, n_seeds:] + C_bb = csd[..., n_seeds:, n_seeds:] + C_ba = csd[..., n_seeds:, :n_seeds] + + C_bar_aa = np.matmul( + U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa)) + C_bar_ab = np.matmul( + U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb)) + C_bar_bb = np.matmul( + U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb)) + C_bar_ba = np.matmul( + U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa)) + C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), + np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) + + return C_bar + + def _compute_autocov(self, csd): + """Compute autocovariance from the CSD.""" + n_times = csd.shape[0] + n_signals = csd.shape[2] + + circular_shifted_csd = np.concatenate( + [np.flip(np.conj(csd[:, 1:]), axis=1), csd[:, :-1]], axis=1) + ifft_shifted_csd = self._block_ifft( + circular_shifted_csd, self.freq_res) + lags_ifft_shifted_csd = np.reshape( + ifft_shifted_csd[:, :self.n_lags + 1], + (n_times, self.n_lags + 1, n_signals ** 2), order="F") + + signs = np.repeat([1], self.n_lags + 1).tolist() + signs[1::2] = [x * -1 for x in signs[1::2]] + sign_matrix = np.repeat( + np.tile(np.array(signs), (n_signals ** 2, 1))[np.newaxis], + n_times, axis=0).transpose(0, 2, 1) + + return np.real(np.reshape( + sign_matrix * lags_ifft_shifted_csd, + (n_times, self.n_lags + 1, n_signals, n_signals), order="F")) + + def _block_ifft(self, csd, n_points): + """Compute block iFFT with n points.""" + shape = csd.shape + csd_3d = np.reshape( + csd, (shape[0], shape[1], shape[2] * shape[3]), order="F") + + csd_ifft = np.fft.ifft(csd_3d, n=n_points, axis=1) + + return np.reshape(csd_ifft, shape, order="F") + + def _autocov_to_full_var(self, autocov): + """Compute full VAR model using Whittle's LWR recursion.""" + if np.any(np.linalg.det(autocov) == 0): + raise RuntimeError( + 'the autocovariance matrix is singular; check if your data is ' + 'rank deficient and specify an appropriate rank argument <= ' + 'the rank of the seeds and targets') + + A_f, V = self._whittle_lwr_recursion(autocov) + + if not np.isfinite(A_f).all(): + raise RuntimeError('at least one VAR model coefficient is ' + 'infinite or NaN; check the data you are using') + + try: + np.linalg.cholesky(V) + except np.linalg.LinAlgError as np_error: + raise RuntimeError( + 'the covariance matrix of the residuals is not ' + 'positive-definite; check the singular values of your data ' + 'and specify an appropriate rank argument <= the rank of the ' + 'seeds and targets') from np_error + + return A_f, V + + def _whittle_lwr_recursion(self, G): + """Solve Yule-Walker eqs. for full VAR params. with LWR recursion. + + See: Whittle P., 1963. Biometrika, DOI: 10.1093/biomet/50.1-2.129 + """ + # Initialise recursion + n = G.shape[2] # number of signals + q = G.shape[1] - 1 # number of lags + t = G.shape[0] # number of times + qn = n * q + + cov = G[:, 0, :, :] # covariance + G_f = np.reshape( + G[:, 1:, :, :].transpose(0, 3, 1, 2), (t, qn, n), + order="F") # forward autocov + G_b = np.reshape( + np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), + order="F").transpose(0, 2, 1) # backward autocov + + A_f = np.zeros((t, n, qn)) # forward coefficients + A_b = np.zeros((t, n, qn)) # backward coefficients + + k = 1 # model order + r = q - k + k_f = np.arange(k * n) # forward indices + k_b = np.arange(r * n, qn) # backward indices + + try: + A_f[:, :, k_f] = np.linalg.solve( + cov, G_b[:, k_b, :].transpose(0, 2, 1)).transpose(0, 2, 1) + A_b[:, :, k_b] = np.linalg.solve( + cov, G_f[:, k_f, :].transpose(0, 2, 1)).transpose(0, 2, 1) + + # Perform recursion + for k in np.arange(2, q + 1): + var_A = (G_b[:, (r - 1) * n: r * n, :] - + np.matmul(A_f[:, :, k_f], G_b[:, k_b, :])) + var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :]) + AA_f = np.linalg.solve( + var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) + + var_A = (G_f[:, (k - 1) * n: k * n, :] - + np.matmul(A_b[:, :, k_b], G_f[:, k_f, :])) + var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :]) + AA_b = np.linalg.solve( + var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) + + A_f_previous = A_f[:, :, k_f] + A_b_previous = A_b[:, :, k_b] + + r = q - k + k_f = np.arange(k * n) + k_b = np.arange(r * n, qn) + + A_f[:, :, k_f] = np.dstack( + (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f)) + A_b[:, :, k_b] = np.dstack( + (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous))) + except np.linalg.LinAlgError as np_error: + raise RuntimeError( + 'the autocovariance matrix is singular; check if your data is ' + 'rank deficient and specify an appropriate rank argument <= ' + 'the rank of the seeds and targets') from np_error + + V = cov - np.matmul(A_f, G_f) + A_f = np.reshape(A_f, (t, n, n, q), order="F") + + return A_f, V + + def _full_var_to_iss(self, A_f): + """Compute innovations-form parameters for a state-space model. + + Parameters computed from a full VAR model using Aoki's method. For a + non-moving-average full VAR model, the state-space parameter C + (observation matrix) is identical to AF of the VAR model. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + t = A_f.shape[0] + m = A_f.shape[1] # number of signals + p = A_f.shape[2] // m # number of autoregressive lags + + I_p = np.dstack(t * [np.eye(m * p)]).transpose(2, 0, 1) + A = np.hstack((A_f, I_p[:, : (m * p - m), :])) # state transition + # matrix + K = np.hstack(( + np.dstack(t * [np.eye(m)]).transpose(2, 0, 1), + np.zeros((t, (m * (p - 1)), m)))) # Kalman gain matrix + + return A, K + + def _iss_to_ugc(self, A, C, K, V, seeds, targets): + """Compute unconditional GC from innovations-form state-space params. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + times = np.arange(A.shape[0]) + freqs = np.arange(self.n_freqs) + z = np.exp(-1j * np.pi * np.linspace(0, 1, self.n_freqs)) # points + # on a unit circle in the complex plane, one for each frequency + + H = self._iss_to_tf(A, C, K, z) # spectral transfer function + V_22_1 = np.linalg.cholesky(self._partial_covar(V, seeds, targets)) + HV = np.matmul(H, np.linalg.cholesky(V)) + S = np.matmul(HV, HV.conj().transpose(0, 1, 3, 2)) # Eq. 6 + S_11 = S[np.ix_(freqs, times, targets, targets)] + HV_12 = np.matmul(H[np.ix_(freqs, times, targets, seeds)], V_22_1) + HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2)) + + # Eq. 11 + return np.real( + np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH))) + + def _iss_to_tf(self, A, C, K, z): + """Compute transfer function for innovations-form state-space params. + + In the frequency domain, the back-shift operator, z, is a vector of + points on a unit circle in the complex plane. z = e^-iw, where -pi < w + <= pi. + + A note on efficiency: solving over the 4D time-freq. tensor is slower + than looping over times and freqs when n_times and n_freqs high, and + when n_times and n_freqs low, looping over times and freqs very fast + anyway (plus tensor solving doesn't allow for parallelisation). + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + t = A.shape[0] + h = self.n_freqs + n = C.shape[1] + m = A.shape[1] + I_n = np.eye(n) + I_m = np.eye(m) + H = np.zeros((h, t, n, n), dtype=np.complex128) + + parallel, parallel_compute_H, _ = parallel_func( + _gc_compute_H, self.n_jobs, verbose=False + ) + H = np.zeros((h, t, n, n), dtype=np.complex128) + for block_i in ProgressBar( + range(self.n_steps), mesg="frequency blocks" + ): + freqs = self._get_block_indices(block_i, self.n_freqs) + H[freqs] = parallel( + parallel_compute_H(A, C, K, z[k], I_n, I_m) for k in freqs) + + return H + + def _partial_covar(self, V, seeds, targets): + """Compute partial covariance of a matrix. + + Given a covariance matrix V, the partial covariance matrix of V between + indices i and j, given k (V_ij|k), is equivalent to V_ij - V_ik * + V_kk^-1 * V_kj. In this case, i and j are seeds, and k are targets. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + times = np.arange(V.shape[0]) + W = np.linalg.solve( + np.linalg.cholesky(V[np.ix_(times, targets, targets)]), + V[np.ix_(times, targets, seeds)], + ) + W = np.matmul(W.transpose(0, 2, 1), W) + + return V[np.ix_(times, seeds, seeds)] - W + + def reshape_results(self): + """Remove time dimension from con. scores, if necessary.""" + if self.n_times == 0: + self.con_scores = self.con_scores[:, :, 0] + + +def _gc_compute_H(A, C, K, z_k, I_n, I_m): + """Compute transfer function for innovations-form state-space params. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101, Eq. 4. + """ + from scipy import linalg # XXX: is this necessary??? + H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) + for t in range(A.shape[0]): + H[t] = I_n + np.matmul( + C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t])) + + return H + + +class _GCEst(_GCEstBase): + """[seeds -> targets] state-space GC estimator.""" + + name = "GC" + + +class _GCTREst(_GCEstBase): + """time-reversed[seeds -> targets] state-space GC estimator.""" + + name = "GC time-reversed" + +############################################################################### + + +# map names to estimator types +_CON_METHOD_MAP = {'mic': _MICEst, 'mim': _MIMEst, 'gc': _GCEst, + 'gc_tr': _GCTREst} + +_gc_methods = ['gc', 'gc_tr'] + + +@ verbose +@ fill_doc +def spectral_connectivity_epochs_multivariate( + data, names=None, method='mic', indices=None, sfreq=None, + mode='multitaper', fmin=None, fmax=np.inf, fskip=0, faverage=False, + tmin=None, tmax=None, mt_bandwidth=None, mt_adaptive=False, + mt_low_bias=True, cwt_freqs=None, cwt_n_cycles=7, gc_n_lags=40, rank=None, + block_size=1000, n_jobs=1, verbose=None +): + r"""Compute multivariate (time-)frequency-domain connectivity measures. + + The connectivity method(s) are specified using the "method" parameter. + All methods are based on estimates of the cross-spectral density (CSD). + + Parameters + ---------- + data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs + The data from which to compute connectivity. Note that it is also + possible to combine multiple signals by providing a list of tuples, + e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], + corresponds to 3 epochs, and arr_* could be an array with the same + number of time points as stc_*. The array-like object can also + be a list/generator of array, shape =(n_signals, n_times), + or a list/generator of SourceEstimate or VolSourceEstimate objects. + %(names)s + method : str | list of str + Connectivity measure(s) to compute. These can be ``['mic', 'mim', 'gc', + 'gc_tr']``. + indices : tuple of array | None + Two arrays with indices of connections for which to compute + connectivity. Each array for the seeds and targets should consist of + nested arrays containing the channel indices for each multivariate + connection. If ``None``, connections between all channels are computed, + unless a Granger causality method is called, in which case an error is + raised. + sfreq : float + The sampling frequency. Required if data is not + :class:`Epochs `. + mode : str + Spectrum estimation mode can be either: 'multitaper', 'fourier', or + 'cwt_morlet'. + fmin : float | tuple of float + The lower frequency of interest. Multiple bands are defined using + a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. + fmax : float | tuple of float + The upper frequency of interest. Multiple bands are dedined using + a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. + fskip : int + Omit every "(fskip + 1)-th" frequency bin to decimate in frequency + domain. + faverage : bool + Average connectivity scores for each frequency band. If True, + the output freqs will be a list with arrays of the frequencies + that were averaged. + tmin : float | None + Time to start connectivity estimation. Note: when "data" is an array, + the first sample is assumed to be at time 0. For other types + (Epochs, etc.), the time information contained in the object is used + to compute the time indices. + tmax : float | None + Time to end connectivity estimation. Note: when "data" is an array, + the first sample is assumed to be at time 0. For other types + (Epochs, etc.), the time information contained in the object is used + to compute the time indices. + mt_bandwidth : float | None + The bandwidth of the multitaper windowing function in Hz. + Only used in 'multitaper' mode. + mt_adaptive : bool + Use adaptive weights to combine the tapered spectra into PSD. + Only used in 'multitaper' mode. + mt_low_bias : bool + Only use tapers with more than 90 percent spectral concentration + within bandwidth. Only used in 'multitaper' mode. + cwt_freqs : array + Array of frequencies of interest. Only used in 'cwt_morlet' mode. + cwt_n_cycles : float | array of float + Number of cycles. Fixed number or one per frequency. Only used in + 'cwt_morlet' mode. + gc_n_lags : int + Number of lags to use when computing Granger causality (the vector + autoregressive model order). Higher values increase computational cost, + but reduce the degree of spectral smoothing in the results. Must be < + (n_freqs - 1) * 2. Only used if ``method`` contains any of ``['gc', + 'gc_tr']``. + rank : tuple of array | None + Two arrays with the rank to project the seed and target data to, + respectively, using singular value decomposition. If None, the rank of + the data is computed and projected to. Only used if ``method`` contains + any of ``['mic', 'mim', 'gc', 'gc_tr']``. + block_size : int + How many CSD entries to compute at once (higher numbers are faster but + require more memory). + n_jobs : int + How many samples to process in parallel. + %(verbose)s + + Returns + ------- + con : array | list of array + Computed connectivity measure(s). Either an instance of + ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. + The shape of the connectivity result will be: + + - ``(n_cons, n_freqs)`` for multitaper or fourier modes + - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode + - ``n_cons = 1`` when ``indices=None`` + - ``n_cons = len(indices[0])`` when indices is supplied + + See Also + -------- + mne_connectivity.spectral_connectivity_epochs + mne_connectivity.spectral_connectivity_time + mne_connectivity.SpectralConnectivity + mne_connectivity.SpectroTemporalConnectivity + + Notes + ----- + Please note that the interpretation of the measures in this function + depends on the data and underlying assumptions and does not necessarily + reflect a causal relationship between brain regions. + + These measures are not to be interpreted over time. Each Epoch passed into + the dataset is interpreted as an independent sample of the same + connectivity structure. Within each Epoch, it is assumed that the spectral + measure is stationary. The spectral measures implemented in this function + are computed across Epochs. **Thus, spectral measures computed with only + one Epoch will result in errorful values and spectral measures computed + with few Epochs will be unreliable.** Please see + ``spectral_connectivity_time`` for time-resolved connectivity estimation. + + The spectral densities can be estimated using a multitaper method with + digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier + transform with Hanning windows, or a continuous wavelet transform using + Morlet wavelets. The spectral estimation mode is specified using the + "mode" parameter. + + By default, "indices" is None, and the connectivity between all signals is + computed and a single connectivity spectrum will be returned (this is not + possible if a Granger causality method is called). If one is only + interested in the connectivity between some signals, the "indices" + parameter can be used. Seed and target indices for each connection should + be specified as nested array-likes. For example, to compute the + connectivity between signals (0, 1) -> (2, 3) and (0, 1) -> (4, 5), indices + should be specified as:: + + indices = ([[0, 1], [0, 1]], # seeds + [[2, 3], [4, 5]]) # targets + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. + + **Supported Connectivity Measures** + + The connectivity method(s) is specified using the "method" parameter. + Multiple measures can be computed at once by using a list/tuple, e.g., + ``['mic', 'gc']``. The following methods are supported: + + 'mic' : Maximised Imaginary part of Coherency (MIC) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} + {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} + \parallel}}` + + where: :math:`\boldsymbol{E}` is the imaginary part of the + transformed cross-spectral density between seeds and targets; and + :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are + eigenvectors for the seeds and targets, such that + :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises + connectivity between the seeds and targets. + + 'mim' : Multivariate Interaction Measure (MIM) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIM=tr(\boldsymbol{EE}^T)` + + 'gc' : State-space Granger Causality (GC) :footcite:`BarnettSeth2015` + given by: + + :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert + \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss + \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, + + where: :math:`s` and :math:`t` represent the seeds and targets, + respectively; :math:`\boldsymbol{H}` is the spectral transfer + function; :math:`\boldsymbol{\Sigma}` is the residuals matrix of + the autoregressive model; and :math:`\boldsymbol{S}` is + :math:`\boldsymbol{\Sigma}` transformed by :math:`\boldsymbol{H}`. + + 'gc_tr' : State-space GC on time-reversed signals + :footcite:`BarnettSeth2015,WinklerEtAl2016` given by the same equation + as for 'gc', but where the autocovariance sequence from which the + autoregressive model is produced is transposed to mimic the reversal of + the original signal in time. + + References + ---------- + .. footbibliography:: + """ + ( + fmin, fmax, n_bands, method, con_method_types, accumulate_psd, + parallel, my_epoch_spectral_connectivity + ) = _check_spectral_connectivity_epochs_settings( + method, fmin, fmax, n_jobs, verbose, _CON_METHOD_MAP) + + if n_bands != 1 and any( + this_method in _gc_methods for this_method in method + ): + raise ValueError('computing Granger causality on multiple frequency ' + 'bands is not yet supported') + + (names, times_in, sfreq, events, event_id, + metadata) = _check_spectral_connectivity_epochs_data(data, sfreq, names) + + # loop over data; it could be a generator that returns + # (n_signals x n_times) arrays or SourceEstimates + epoch_idx = 0 + logger.info('Connectivity computation...') + warn_times = True + for epoch_block in _get_n_epochs(data, n_jobs): + if epoch_idx == 0: + # initialize everything times and frequencies + (times, n_times, times_in, n_times_in, tmin_idx, tmax_idx, n_freqs, + freq_mask, freqs, freqs_bands, freq_idx_bands, n_signals, + warn_times) = _prepare_connectivity( + epoch_block=epoch_block, times_in=times_in, tmin=tmin, + tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, mode=mode, + fskip=fskip, n_bands=n_bands, cwt_freqs=cwt_freqs, + faverage=faverage) + + # check indices input + n_cons, indices_use = _check_indices(indices, method, n_signals) + + # check rank input and compute data ranks + rank = _check_rank_input(rank, data, indices_use) + + # make sure padded indices are stored in the connectivity object + if indices is not None: + indices = tuple(np.array(indices_use)) # create a copy + + # get the window function, wavelets, etc for different modes + (spectral_params, mt_adaptive, n_times_spectrum, + n_tapers) = _assemble_spectral_params( + mode=mode, n_times=n_times, mt_adaptive=mt_adaptive, + mt_bandwidth=mt_bandwidth, sfreq=sfreq, + mt_low_bias=mt_low_bias, cwt_n_cycles=cwt_n_cycles, + cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) + + # unique signals for which we actually need to compute CSD + sig_idx = np.unique(np.concatenate(np.concatenate( + indices_use))) + sig_idx = sig_idx[sig_idx != -1] + remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} + remapping[-1] = -1 + remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + remapped_inds[0][con_i] = np.array([ + remapping[idx] for idx in seed]) + remapped_inds[1][con_i] = np.array([ + remapping[idx] for idx in target]) + con_i += 1 + remapped_sig = [remapping[idx] for idx in sig_idx] + n_signals_use = len(sig_idx) + + # map indices to unique indices + indices_use = remapped_inds # use remapped seeds & targets + idx_map = [np.sort(np.repeat(remapped_sig, len(sig_idx))), + np.tile(remapped_sig, len(sig_idx))] + + # create instances of the connectivity estimators + con_methods = [] + for mtype_i, mtype in enumerate(con_method_types): + method_params = dict(n_cons=n_cons, n_freqs=n_freqs, + n_times=n_times_spectrum, + n_signals=n_signals_use) + if method[mtype_i] in _gc_methods: + method_params.update(dict(n_lags=gc_n_lags)) + con_methods.append(mtype(**method_params)) + + sep = ', ' + metrics_str = sep.join([meth.name for meth in con_methods]) + logger.info(' the following metrics will be computed: %s' + % metrics_str) + + call_params = dict( + sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, + method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, + n_cons=n_cons, block_size=block_size, + psd=None, accumulate_psd=accumulate_psd, + mt_adaptive=mt_adaptive, + con_method_types=con_method_types, + con_methods=con_methods if n_jobs == 1 else None, + n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, + gc_n_lags=gc_n_lags, multivariate_con=True, + accumulate_inplace=True if n_jobs == 1 else False) + call_params.update(**spectral_params) + + epoch_idx = _compute_spectral_methods_epochs( + con_methods, epoch_block, epoch_idx, call_params, parallel, + my_epoch_spectral_connectivity, n_jobs, n_times_in, times_in, + warn_times) + n_epochs = epoch_idx + + # compute final connectivity scores + con = list() + patterns = list() + for conn_method in con_methods: + + # compute connectivity scores + conn_method.compute_con(indices_use, rank, n_epochs) + + # get the connectivity scores + this_con = conn_method.con_scores + this_patterns = conn_method.patterns + + if this_con.shape[0] != n_cons: + raise RuntimeError( + 'first dimension of connectivity scores does not match the ' + 'number of connections; please contact the mne-connectivity ' + 'developers') + if faverage: + if this_con.shape[1] != n_freqs: + raise RuntimeError( + 'second dimension of connectivity scores does not match ' + 'the number of frequencies; please contact the ' + 'mne-connectivity developers') + con_shape = (n_cons, n_bands) + this_con.shape[2:] + this_con_bands = np.empty(con_shape, dtype=this_con.dtype) + for band_idx in range(n_bands): + this_con_bands[:, band_idx] = np.mean( + this_con[:, freq_idx_bands[band_idx]], axis=1) + this_con = this_con_bands + + if this_patterns is not None: + patterns_shape = list(this_patterns.shape) + patterns_shape[3] = n_bands + this_patterns_bands = np.empty(patterns_shape, + dtype=this_patterns.dtype) + for band_idx in range(n_bands): + this_patterns_bands[:, :, :, band_idx] = np.mean( + this_patterns[:, :, :, freq_idx_bands[band_idx]], + axis=3) + this_patterns = this_patterns_bands + + con.append(this_con) + patterns.append(this_patterns) + + conn_list = _store_results( + con=con, patterns=patterns, method=method, freqs=freqs, + faverage=faverage, freqs_bands=freqs_bands, names=names, mode=mode, + indices=indices, n_epochs=n_epochs, times=times, n_tapers=n_tapers, + metadata=metadata, events=events, event_id=event_id, rank=rank, + gc_n_lags=gc_n_lags, n_signals=n_signals) + + return conn_list diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 592291f0..54bfafa5 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -9,10 +9,11 @@ from mne_connectivity import ( SpectralConnectivity, spectral_connectivity_epochs, + spectral_connectivity_epochs_multivariate, read_connectivity, spectral_connectivity_time) -from mne_connectivity.spectral.epochs import _CohEst, _get_n_epochs from mne_connectivity.spectral.epochs import ( - _compute_freq_mask, _compute_freqs) + _get_n_epochs, _compute_freq_mask, _compute_freqs) +from mne_connectivity.spectral.epochs_bivariate import _CohEst def create_test_dataset(sfreq, n_signals, n_epochs, n_times, tmin, tmax, @@ -447,7 +448,7 @@ def test_spectral_connectivity_epochs_multivariate(method): data = data.reshape(n_signals, n_epochs, n_times) data = np.transpose(data, [1, 0, 2]) - con = spectral_connectivity_epochs( + con = spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=20) freqs = con.freqs @@ -473,17 +474,17 @@ def test_spectral_connectivity_epochs_multivariate(method): # check that target -> seed connectivity is low indices_ts = (indices[1], indices[0]) - con_ts = spectral_connectivity_epochs( + con_ts = spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=indices_ts, sfreq=sfreq, gc_n_lags=20) assert con_ts.get_data()[0, gidx[0]:gidx[1]].mean() < lower_t # check that TRGC is positive (i.e. net seed -> target connectivity not # due to noise) - con_tr = spectral_connectivity_epochs( + con_tr = spectral_connectivity_epochs_multivariate( data, method='gc_tr', mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=20) - con_ts_tr = spectral_connectivity_epochs( + con_ts_tr = spectral_connectivity_epochs_multivariate( data, method='gc_tr', mode=mode, indices=indices_ts, sfreq=sfreq, gc_n_lags=20) trgc = ((con.get_data() - con_ts.get_data()) - @@ -497,7 +498,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check all-to-all conn. computed for MIC/MIM when no indices given if method in ['mic', 'mim']: - con = spectral_connectivity_epochs( + con = spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=None, sfreq=sfreq) assert con.indices is None assert con.n_nodes == n_signals @@ -506,7 +507,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check ragged indices padded correctly ragged_indices = (np.array([[0]]), np.array([[1, 2]])) - con = spectral_connectivity_epochs( + con = spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq) assert np.all(np.array(con.indices) == np.array([np.array([[0, -1]]), np.array([[1, 2]])])) @@ -514,7 +515,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check shape of MIC patterns if method == 'mic': for mode in ['multitaper', 'cwt_morlet']: - con = spectral_connectivity_epochs( + con = spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=10, fmax=25, cwt_freqs=np.arange(10, 25), faverage=True) @@ -535,7 +536,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check patterns averaged over freqs fmin = (5., 15.) fmax = (15., 30.) - con = spectral_connectivity_epochs( + con = spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=fmin, fmax=fmax, faverage=True) assert np.shape(con.attrs["patterns"][0][0])[1] == len(fmin) @@ -543,7 +544,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check patterns shape matches input data, not rank rank = (np.array([1]), np.array([1])) - con = spectral_connectivity_epochs( + con = spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=rank) assert (np.shape(con.attrs["patterns"][0][0])[0] == n_seeds) @@ -551,7 +552,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check patterns padded correctly ragged_indices = (np.array([[0]]), np.array([[1, 2]])) - con = spectral_connectivity_epochs( + con = spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq) patterns = np.array(con.attrs["patterns"]) @@ -586,7 +587,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): sfreq = 100 indices = (np.array([[0, 1]]), np.array([[2, 3]])) methods = ['mic', 'mim', 'gc', 'gc_tr'] - con = spectral_connectivity_epochs( + con = spectral_connectivity_epochs_multivariate( data, method=methods, indices=indices, mode='multitaper', sfreq=sfreq, fskip=0, faverage=False, tmin=0, tmax=None, mt_bandwidth=4, mt_low_bias=True, mt_adaptive=False, gc_n_lags=20, @@ -594,8 +595,9 @@ def test_multivariate_spectral_connectivity_epochs_regression(): # should take the absolute of the MIC scores, as the MATLAB implementation # returns the absolute values. - mne_results = {this_con.method: np.abs(this_con.get_data()) - for this_con in con} + mne_results = {this_con.method: this_con.get_data() for this_con in con} + mne_results["mic"] = np.abs(mne_results["mic"]) + matlab_results = pd.read_pickle( os.path.join(fpath, 'data', 'example_multivariate_matlab_results.pkl')) for method in methods: @@ -620,40 +622,29 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): with pytest.raises(TypeError, match='multivariate indices must contain array-likes'): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) - spectral_connectivity_epochs( + spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=non_nested_indices, - sfreq=sfreq, gc_n_lags=10) + sfreq=sfreq, cwt_freqs=cwt_freqs, gc_n_lags=10) # check bad indices with repeated channels caught with pytest.raises(ValueError, match='multivariate indices cannot contain repeated'): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) - spectral_connectivity_epochs( + spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=repeated_indices, - sfreq=sfreq, gc_n_lags=10) - - # check mixed methods caught - with pytest.raises(ValueError, - match='bivariate and multivariate connectivity'): - if isinstance(method, str): - mixed_methods = [method, 'coh'] - elif isinstance(method, list): - mixed_methods = [*method, 'coh'] - spectral_connectivity_epochs(data, method=mixed_methods, mode=mode, - indices=indices, sfreq=sfreq, - cwt_freqs=cwt_freqs) + sfreq=sfreq, cwt_freqs=cwt_freqs, gc_n_lags=10) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) with pytest.raises(ValueError, match='ranks for seeds and targets must be'): - spectral_connectivity_epochs( + spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=too_low_rank, cwt_freqs=cwt_freqs) too_high_rank = (np.array([3]), np.array([3])) with pytest.raises(ValueError, match='ranks for seeds and targets must be'): - spectral_connectivity_epochs( + spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=too_high_rank, cwt_freqs=cwt_freqs) @@ -664,7 +655,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): assert np.all(np.linalg.matrix_rank(bad_data[:, (0, 1), :]) == 1) assert np.all(np.linalg.matrix_rank(bad_data[:, (2, 3), :]) == 1) if isinstance(method, str): - rank_con = spectral_connectivity_epochs( + rank_con = spectral_connectivity_epochs_multivariate( bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=10, cwt_freqs=cwt_freqs) assert rank_con.attrs["rank"] == (np.array([1]), np.array([1])) @@ -673,7 +664,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): # check rank-deficient transformation matrix caught with pytest.raises(RuntimeError, match='the transformation matrix'): - spectral_connectivity_epochs( + spectral_connectivity_epochs_multivariate( bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=(np.array([2]), np.array([2])), cwt_freqs=cwt_freqs) @@ -684,37 +675,36 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): frange = (5, 10) n_lags = 200 # will be far too high with pytest.raises(ValueError, match='the number of lags'): - spectral_connectivity_epochs( + spectral_connectivity_epochs_multivariate( data, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=frange[0], fmax=frange[1], gc_n_lags=n_lags, cwt_freqs=cwt_freqs) # check no indices caught with pytest.raises(ValueError, match='indices must be specified'): - spectral_connectivity_epochs(data, method=method, mode=mode, - indices=None, sfreq=sfreq, - cwt_freqs=cwt_freqs) + spectral_connectivity_epochs_multivariate( + data, method=method, mode=mode, indices=None, sfreq=sfreq, + cwt_freqs=cwt_freqs) # check intersecting indices caught bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) with pytest.raises(ValueError, match='seed and target indices must not intersect'): - spectral_connectivity_epochs(data, method=method, mode=mode, - indices=bad_indices, sfreq=sfreq, - cwt_freqs=cwt_freqs) + spectral_connectivity_epochs_multivariate( + data, method=method, mode=mode, indices=bad_indices, + sfreq=sfreq, cwt_freqs=cwt_freqs) # check bad fmin/fmax caught with pytest.raises(ValueError, match='computing Granger causality on multiple'): - spectral_connectivity_epochs(data, method=method, mode=mode, - indices=indices, sfreq=sfreq, - fmin=(10., 15.), fmax=(15., 20.), - cwt_freqs=cwt_freqs) + spectral_connectivity_epochs_multivariate( + data, method=method, mode=mode, indices=indices, sfreq=sfreq, + fmin=(10., 15.), fmax=(15., 20.), cwt_freqs=cwt_freqs) # check rank-deficient autocovariance caught with pytest.raises(RuntimeError, match='the autocovariance matrix is singular'): - spectral_connectivity_epochs( + spectral_connectivity_epochs_multivariate( bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=(np.array([2]), np.array([2])), cwt_freqs=cwt_freqs) @@ -731,7 +721,7 @@ def test_multivar_spectral_connectivity_parallel(method): data = rng.randn(n_epochs, n_signals, n_times) indices = (np.array([[0, 1]]), np.array([[2, 3]])) - spectral_connectivity_epochs( + spectral_connectivity_epochs_multivariate( data, method=method, mode="multitaper", indices=indices, sfreq=sfreq, gc_n_lags=10, n_jobs=2) spectral_connectivity_time( @@ -761,12 +751,13 @@ def test_multivar_spectral_connectivity_flipped_indices(): # we test on GC since this is a directed connectivity measure method = 'gc' - con_st = spectral_connectivity_epochs( # seed -> target + con_st = spectral_connectivity_epochs_multivariate( # seed -> target data, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10) - con_ts = spectral_connectivity_epochs( # target -> seed + con_ts = spectral_connectivity_epochs_multivariate( # target -> seed data, method=method, indices=flipped_indices, sfreq=sfreq, gc_n_lags=10) - con_st_ts = spectral_connectivity_epochs( # seed -> target; target -> seed + con_st_ts = spectral_connectivity_epochs_multivariate( + # seed -> target; target -> seed data, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10) assert not np.all(con_st.get_data() == con_ts.get_data()) assert np.all(con_st.get_data()[0] == con_st_ts.get_data()[0]) @@ -1298,7 +1289,7 @@ def test_multivar_save_load(tmp_path): non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]])) ragged_indices = (np.array([[0, 1]]), np.array([[2]])) for indices in [non_ragged_indices, ragged_indices]: - con = spectral_connectivity_epochs( + con = spectral_connectivity_epochs_multivariate( epochs, method=['mic', 'mim', 'gc', 'gc_tr'], indices=indices, sfreq=sfreq, fmin=10, fmax=30) for this_con in con: @@ -1315,12 +1306,9 @@ def test_multivar_save_load(tmp_path): assert a == b -@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", - "mic", "mim"]) +@pytest.mark.parametrize("method", ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize("indices", [None, - (np.array([0, 1]), np.array([2, 3])), - (np.array([[0, 1]]), np.array([[2, 3]])) - ]) + (np.array([0, 1]), np.array([2, 3]))]) def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. @@ -1337,14 +1325,6 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") - # mutlivariate and bivariate methods require the right indices shape - if method in ["mic", "mim"]: - if indices is not None and indices[0].ndim == 1: - pytest.skip() - else: - if indices is not None and indices[0].ndim == 2: - pytest.skip() - # test the pair of method and indices defined to check the output indices con_epochs = spectral_connectivity_epochs( epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 @@ -1366,3 +1346,53 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): assert np.all(np.array(con.indices) == np.array(read_con.indices)) else: assert con.indices is None and read_con.indices is None + + +@pytest.mark.parametrize("method", ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize("indices", [None, + (np.array([[0, 1]]), np.array([[2, 3]]))]) +def test_multivar_spectral_connectivity_indices_roundtrip_io( + tmp_path, method, indices +): + """Test that indices values and type is maintained after saving. + + If `indices` is None, `indices` in the returned connectivity object should + be None, otherwise, `indices` should be a tuple. The type of `indices` and + its values should be retained after saving and reloading. + """ + rng = np.random.RandomState(0) + n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 + data = rng.randn(n_epochs, n_chs, n_times) + info = create_info(n_chs, sfreq, "eeg") + tmin = -1 + epochs = EpochsArray(data, info, tmin=tmin) + freqs = np.arange(10, 31) + tmp_file = os.path.join(tmp_path, "foo_mvc.nc") + + # test the pair of method and indices defined to check the output indices + if indices is None and method in ['gc', 'gc_tr']: + # indicesmust be specified for GC + pytest.skip() + + con_epochs = spectral_connectivity_epochs_multivariate( + epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30, + gc_n_lags=10 + ) + con_time = spectral_connectivity_time( + epochs, freqs, method=method, indices=indices, sfreq=sfreq, + gc_n_lags=10 + ) + + for con in [con_epochs, con_time]: + con.save(tmp_file) + read_con = read_connectivity(tmp_file) + + if indices is not None: + # check indices of same type (tuples) + assert isinstance(con.indices, tuple) and isinstance( + read_con.indices, tuple + ) + # check indices have same values + assert np.all(np.array(con.indices) == np.array(read_con.indices)) + else: + assert con.indices is None and read_con.indices is None diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 3798f699..d0059ace 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -13,8 +13,9 @@ from mne.utils import (logger, verbose) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) -from .epochs import (_MICEst, _MIMEst, _GCEst, _GCTREst, _compute_freq_mask, - _check_rank_input) +from .epochs import _compute_freq_mask +from .epochs_multivariate import (_MICEst, _MIMEst, _GCEst, _GCTREst, + _check_rank_input) from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, check_multivariate_indices, fill_doc From 77df347d9a8a1d3e0c510a4f3e5f9cb5ac8aacb2 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 24 Oct 2023 17:40:20 +0200 Subject: [PATCH 22/40] try fix ci error --- ignore_words.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ignore_words.txt b/ignore_words.txt index 1375fa21..6dc3af46 100644 --- a/ignore_words.txt +++ b/ignore_words.txt @@ -1,4 +1,5 @@ nd adn ba -BA \ No newline at end of file +BA +Manuel \ No newline at end of file From c284589bbbc9424b7c69b7b20b44e900af8e4746 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 24 Oct 2023 23:18:27 +0200 Subject: [PATCH 23/40] bug fix missing refactoring for example --- examples/handling_ragged_arrays.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index 67f06a4e..2bc0dcc0 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -14,7 +14,7 @@ import numpy as np -from mne_connectivity import spectral_connectivity_epochs +from mne_connectivity import spectral_connectivity_epochs_multivariate ############################################################################### # Background @@ -44,7 +44,7 @@ # targets = [[2, 3, 4], [4 ]] # # The ``indices`` parameter passed to -# :func:`~mne_connectivity.spectral_connectivity_epochs` and +# :func:`~mne_connectivity.spectral_connectivity_epochs_multivariate` and # :func:`~mne_connectivity.spectral_connectivity_time` must be a tuple of # array-likes, meaning # that the indices can be passed as a tuple of: lists; tuples; or NumPy arrays. @@ -108,7 +108,7 @@ [[2, 3, 4], [4]]) # targets # compute connectivity -con = spectral_connectivity_epochs( +con = spectral_connectivity_epochs_multivariate( data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30, verbose=False) patterns = np.array(con.attrs['patterns']) From 0c54943c3178543f31cda149bc3710f064e94d25 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 26 Oct 2023 22:47:59 +0200 Subject: [PATCH 24/40] switch to masked arrays for indices --- examples/handling_ragged_arrays.py | 39 +++--- mne_connectivity/io.py | 6 +- .../spectral/epochs_multivariate.py | 53 ++++---- .../spectral/tests/test_spectral.py | 6 +- mne_connectivity/spectral/time.py | 42 +++--- mne_connectivity/tests/test_utils.py | 16 ++- mne_connectivity/utils/utils.py | 125 ++++++++++-------- 7 files changed, 154 insertions(+), 133 deletions(-) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index 2bc0dcc0..ecb4a9c4 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -62,8 +62,8 @@ # ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'), # np.array([[2, 3, 4], [4 ]], dtype='object')) # -# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be -# specified.** +# **N.B. Note that since NumPy v1.19.0, dtype='object' must be specified when +# forming ragged arrays.** # # Just as for bivariate connectivity, the length of ``indices[0]`` and # ``indices[1]`` is equal (i.e. the number of connections), however information @@ -71,19 +71,16 @@ # array. Importantly, these indices are ragged, as the first connection will be # computed between 2 seed and 3 target channels, and the second connection # between 4 seed and 1 target channel. The connectivity functions will -# recognise the indices as being ragged, and pad them accordingly to make them -# easier to work with and compatible with the h5netcdf saving engine. The known -# value used to pad the arrays is ``-1``, an invalid channel index. The above -# indices would be padded to:: +# recognise the indices as being ragged, and pad them to a 'full' array by +# adding placeholder values which are masked accordingly. This makes the +# indices easier to work with, and also compatible with the engine used to save +# connectivity objects. For example, the above indices would become:: # -# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]), -# np.array([[2, 3, 4, -1], [4, -1, -1, -1]])) +# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]), +# np.array([[2, 3, 4, --], [4, --, --, --]])) # -# These indices are what is stored in the connectivity object, and is also the -# format of indices returned from the helper functions -# :func:`~mne_connectivity.check_multivariate_indices` and -# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also -# possible to pass the padded indices to the connectivity functions directly. +# where ``--`` are masked entries. These indices are what is stored in the +# returned connectivity objects. # # For the connectivity results themselves, the methods available in # MNE-Connectivity combine information across the different channels into a @@ -118,11 +115,11 @@ max_n_chans = max( [len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])]) -# show that the padded indices entries are all -1 -assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels -assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels -assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels -assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels +# show that the padded indices entries are masked +assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels +assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels +assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels +assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels # patterns have shape [seeds/targets x cons x max channels x freqs (x times)] assert patterns.shape == (2, n_cons, max_n_chans, n_freqs) @@ -137,11 +134,11 @@ seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] -# extract patterns for second connection using the padded indices (pad = -1) +# extract patterns for second connection using the padded, masked indices seed_patterns_con2 = ( - patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)]) + patterns[0, 1, :padded_indices[0][1].count()]) target_patterns_con2 = ( - patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)]) + patterns[1, 1, :padded_indices[1][1].count()]) # show that shapes of patterns are correct assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) diff --git a/mne_connectivity/io.py b/mne_connectivity/io.py index 63aa3501..e8d9b916 100644 --- a/mne_connectivity/io.py +++ b/mne_connectivity/io.py @@ -53,9 +53,11 @@ def _xarray_to_conn(array, cls_func): event_id = dict(zip(event_id_keys, event_id_vals)) array.attrs['event_id'] = event_id - # convert indices numpy arrays to a tuple of arrays + # convert indices numpy arrays to a tuple of masked arrays + # (only multivariate connectivity indices saved as arrays) if isinstance(array.attrs['indices'], np.ndarray): - array.attrs['indices'] = tuple(array.attrs['indices']) + array.attrs['indices'] = tuple( + np.ma.masked_values(array.attrs['indices'], -1)) # create the connectivity class conn = cls_func( diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 28077adb..db648456 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -31,12 +31,18 @@ def _check_indices(indices, method, n_signals): logger.info('using all indices for multivariate connectivity') indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], np.arange(n_signals, dtype=int)[np.newaxis, :]) + indices_use = np.ma.masked_array(indices_use, + mask=False, fill_value=-1) else: - indices_use = check_multivariate_indices(indices) # pad with -1 + indices_use = check_multivariate_indices(indices) # mask indices + indices_use = np.ma.concatenate([inds[np.newaxis] for inds in + indices_use]) + np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices[0], indices[1]): - intersection = np.intersect1d(seed, target) - if np.any(intersection != -1): # ignore padded entries + for seed, target in zip(indices_use[0], indices_use[1]): + intersection = np.intersect1d(seed.compressed(), + target.compressed()) + if intersection.size > 0: raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') @@ -59,16 +65,10 @@ def _check_rank_input(rank, data, indices): else: data_arr = data - # XXX: Unpadding of arrays after already padding them is perhaps not so - # efficient. However, we need to remove the padded values to - # ensure only the correct channels are indexed, and having two - # versions of indices is a bit messy currently. A candidate for - # refactoring to simplify code. - for group_i in range(2): # seeds and targets for con_i, con_idcs in enumerate(indices[group_i]): - con_idcs = con_idcs[con_idcs != -1] # -1 is padded value - s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) + s = np.linalg.svd(data_arr[:, con_idcs.compressed()], + compute_uv=False) rank[group_i][con_i] = np.min( [np.count_nonzero(epoch >= epoch[0] * sv_tol) for epoch in s]) @@ -197,8 +197,8 @@ def compute_con(self, indices, ranks, n_epochs=1): indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] + seed_idcs = seed_idcs.compressed() + target_idcs = target_idcs.compressed() con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] @@ -432,8 +432,8 @@ def compute_con(self, indices, ranks, n_epochs=1): indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] + seed_idcs = seed_idcs.compressed() + target_idcs = target_idcs.compressed() con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] @@ -1009,7 +1009,8 @@ def spectral_connectivity_epochs_multivariate( # make sure padded indices are stored in the connectivity object if indices is not None: - indices = tuple(np.array(indices_use)) # create a copy + # create a copy + indices = (indices_use[0].copy(), indices_use[1].copy()) # get the window function, wavelets, etc for different modes (spectral_params, mt_adaptive, n_times_spectrum, @@ -1020,20 +1021,12 @@ def spectral_connectivity_epochs_multivariate( cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) # unique signals for which we actually need to compute CSD - sig_idx = np.unique(np.concatenate(np.concatenate( - indices_use))) - sig_idx = sig_idx[sig_idx != -1] + sig_idx = np.unique(indices_use.compressed()) remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} - remapping[-1] = -1 - remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) - con_i = 0 - for seed, target in zip(indices_use[0], indices_use[1]): - remapped_inds[0][con_i] = np.array([ - remapping[idx] for idx in seed]) - remapped_inds[1][con_i] = np.array([ - remapping[idx] for idx in target]) - con_i += 1 - remapped_sig = [remapping[idx] for idx in sig_idx] + remapped_inds = indices_use.copy() + for idx in sig_idx: + remapped_inds[indices_use == idx] = remapping[idx] + remapped_sig = np.unique(remapped_inds.compressed()) n_signals_use = len(sig_idx) # map indices to unique indices diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 54bfafa5..d05ec61e 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1392,7 +1392,11 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io( assert isinstance(con.indices, tuple) and isinstance( read_con.indices, tuple ) + # check indices are masked + assert all([np.ma.isMA(inds) for inds in con.indices] and + [np.ma.isMA(inds) for inds in read_con.indices]) # check indices have same values - assert np.all(np.array(con.indices) == np.array(read_con.indices)) + assert np.all([con_inds == read_inds for con_inds, read_inds in + zip(con.indices, read_con.indices)]) else: assert con.indices is None and read_con.indices is None diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index d0059ace..d44f5be3 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -408,46 +408,50 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') logger.info('using all indices for multivariate connectivity') - indices_use = (np.array([np.arange(n_signals, dtype=np.int32)]), - np.array([np.arange(n_signals, dtype=np.int32)])) + indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) + indices_use = np.ma.masked_array(indices_use, + mask=False, fill_value=-1) else: logger.info('only using indices for lower-triangular matrix') indices_use = np.tril_indices(n_signals, k=-1) else: if multivariate_con: - indices_use = check_multivariate_indices(indices) # pad with -1 + indices_use = check_multivariate_indices(indices) # mask indices + indices_use = np.ma.concatenate([inds[np.newaxis] for inds in + indices_use]) + np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices[0], indices[1]): - intersection = np.intersect1d(seed, target) - if np.any(intersection != -1): # ignore padded entries + for seed, target in zip(indices_use[0], indices_use[1]): + intersection = np.intersect1d(seed.compressed(), + target.compressed()) + if intersection.size > 0: raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') # make sure padded indices are stored in the connectivity object - indices = tuple(np.array(indices_use)) # create a copy + # create a copy + indices = (indices_use[0].copy(), indices_use[1].copy()) else: indices_use = check_indices(indices) - # create copies of indices_use for independent manipulation - source_idx = np.array(indices_use[0]) - target_idx = np.array(indices_use[1]) - n_cons = len(source_idx) + n_cons = len(indices_use[0]) # unique signals for which we actually need to compute the CSD of if multivariate_con: - signals_use = np.unique(np.concatenate(np.concatenate(indices_use))) - signals_use = signals_use[signals_use != -1] + signals_use = np.unique(indices_use.compressed()) remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(signals_use)} - remapping[-1] = -1 + remapped_inds = indices_use.copy() # multivariate functions expect seed/target remapping - con_i = 0 - for seed, target in zip(indices_use[0], indices_use[1]): - source_idx[con_i] = np.array([remapping[idx] for idx in seed]) - target_idx[con_i] = np.array([remapping[idx] for idx in target]) - con_i += 1 + for idx in signals_use: + remapped_inds[indices_use == idx] = remapping[idx] + source_idx = remapped_inds[0] + target_idx = remapped_inds[1] max_n_channels = len(indices_use[0][0]) else: # no indices remapping required for bivariate functions signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + source_idx = indices_use[0].copy() + target_idx = indices_use[1].copy() max_n_channels = len(indices_use[0]) # check rank input and compute data ranks if necessary diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index 1e5822eb..31ed3c79 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -34,14 +34,22 @@ def test_seed_target_indices(): seeds = [[0, 1]] targets = [[2, 3], [3, 4]] indices = seed_target_multivariate_indices(seeds, targets) - assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), - np.array([[2, 3], [3, 4]]))) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3], [3, 4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) # ragged indices seeds = [[0, 1]] targets = [[2, 3, 4], [4]] indices = seed_target_multivariate_indices(seeds, targets) - assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), - np.array([[2, 3, 4], [4, -1, -1]]))) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3, 4], [4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) # test error catching # non-array-like seeds/targets with pytest.raises(TypeError, diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index a0c69bd9..5b5f634b 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -83,18 +83,23 @@ def check_indices(indices): return indices -def check_multivariate_indices(indices): - """Check indices parameter for multivariate connectivity and pad it. +def check_multivariate_indices(indices, n_chans=None): + """Check indices parameter for multivariate connectivity and mask it. Parameters ---------- indices : tuple of array of array of int, shape (2, n_cons, variable) Tuple containing index sets. + n_chans : int | None (default None) + The number of channels in the data. Used when converting negative + indices to positive indices. Cannot be ``None`` if negative indices are + present. + Returns ------- indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans) - The indices padded with the invalid channel index ``-1``. + The indices as a masked array. Notes ----- @@ -108,26 +113,35 @@ def check_multivariate_indices(indices): connection must be unique. If the seed and target indices are given as lists or tuples, they will be - converted to numpy arrays. In case the number of channels differs across + converted to numpy arrays. Because the number of channels can differ across connections or between the seeds and targets for a given connection (i.e. - ragged indices), the returned array will be padded with the invalid channel - index ``-1`` according to the maximum number of channels in the seed or - target of any one connection. E.g. the ragged indices of shape ``(2, - n_cons, variable)``:: + ragged/jagged indices), the returned array will be padded out to a 'full' + array with an invalid index (``-1``) according to the maximum number of + channels in the seed or target of any one connection. These invalid + entries are then masked and returned as numpy masked arrays. E.g. the + ragged indices of shape ``(2, n_cons, variable)``:: indices = ([[0, 1], [0, 1 ]], # seeds [[2, 3], [4, 5, 6]]) # targets - would be returned as:: + would be padded to full arrays:: + + indices = ([[0, 1, -1], [0, 1, -1]], # seeds + [[2, 3, -1], [4, 5, 6]]) # targets - indices = (np.array([[0, 1, -1], [0, 1, -1]]), # seeds - np.array([[2, 3, -1], [4, 5, -1]])) # targets + to have shape ``(2, n_cons, max_n_chans)``, where ``max_n_chans = 3``. The + invalid entries are then masked:: - where the indices have been padded with ``-1`` to have shape ``(2, n_cons, - max_n_chans)``, where ``max_n_chans = 3``. More information on working with - multivariate indices and handling connections where the number of seeds and - targets are not equal can be found in the - :doc:`../auto_examples/handling_ragged_arrays` example. + indices = ([[0, 1, --], [0, 1, --]], # seeds + [[2, 3, --], [4, 5, 6]]) # targets + + In case "indices" contains negative values to index channels, these will be + converted to the corresponding positive-valued index before any masking is + applied. + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. """ if not isinstance(indices, tuple) or len(indices) != 2: raise ValueError('indices must be a tuple of length 2') @@ -137,6 +151,7 @@ def check_multivariate_indices(indices): 'have the same length') n_cons = len(indices[0]) + invalid = -1 max_n_chans = 0 for inds in ([*indices[0], *indices[1]]): @@ -149,17 +164,32 @@ def check_multivariate_indices(indices): 'multivariate indices cannot contain repeated channels within ' 'a seed or target') max_n_chans = max(max_n_chans, len(inds)) + # convert negative to positive indices + if any(idx < 0 for idx in inds): + if n_chans is None: + raise ValueError( + '`n_chans` must be given if there are negative values ' + 'in `indices`') + if any(idx * -1 > n_chans for idx in inds[inds < 0]): + raise ValueError( + 'a channel index is not present in the data' + ) + inds[inds < 0] = inds[inds < 0] % n_chans # pad indices to avoid ragged arrays - padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32), - np.full((n_cons, max_n_chans), -1, dtype=np.int32)) + padded_indices = (np.full((n_cons, max_n_chans), invalid, dtype=np.int32), + np.full((n_cons, max_n_chans), invalid, dtype=np.int32)) con_i = 0 for seed, target in zip(indices[0], indices[1]): padded_indices[0][con_i, :len(seed)] = seed padded_indices[1][con_i, :len(target)] = target con_i += 1 - return padded_indices + # mask invalid indices + masked_indices = (np.ma.masked_values(padded_indices[0], invalid), + np.ma.masked_values(padded_indices[1], invalid)) + + return masked_indices def seed_target_indices(seeds, targets): @@ -221,8 +251,8 @@ def seed_target_multivariate_indices(seeds, targets): Returns ------- - indices : tuple of array of array of int, shape (2, n_cons, max_n_chans) - The indices padded with the invalid channel index ``-1``. + indices : tuple of array of array of int, shape (2, n_cons, variable) + The indices as a numpy object array. Notes ----- @@ -232,12 +262,8 @@ def seed_target_multivariate_indices(seeds, targets): channels in the data. The length of indices for each connection do not need to be equal. Furthermore, all indices within a connection must be unique. - ``seeds`` and ``targets`` will be expanded such that connectivity will be - computed between each set of seeds and targets. In case the number of - channels differs across connections or between the seeds and targets for a - given connection (i.e. ragged indices), the returned array will be padded - with the invalid channel index ``-1`` according to the maximum number of - channels in the seed or target of any one connection. E.g. ``seeds`` and + Because the number of channels per connection can vary, the indices are + stored as numpy arrays with ``dtype=object``. E.g. ``seeds`` and ``targets``:: seeds = [[0]] @@ -245,15 +271,15 @@ def seed_target_multivariate_indices(seeds, targets): would be returned as:: - indices = (np.array([[0, -1, -1], [0, -1, -1]]), # seeds - np.array([[1, 2, -1], [3, 4, 5]])) # targets + indices = (np.array([[0 ], [0 ]], dtype=object), # seeds + np.array([[1, 2], [3, 4, 5]], dtype=object)) # targets + + Even if the number of channels does not vary, the indices will still be + stored as object arrays for compatability. - where the indices have been padded with ``-1`` to have shape ``(2, n_cons, - max_n_chans)``, where ``n_cons = n_unique_seeds * n_unique_targets`` and - ``max_n_chans = 3``. More information on working with multivariate indices - and handling connections where the number of seeds and targets are not - equal can be found in the :doc:`../auto_examples/handling_ragged_arrays` - example. + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. """ array_like = (np.ndarray, list, tuple) @@ -263,7 +289,6 @@ def seed_target_multivariate_indices(seeds, targets): ): raise TypeError('`seeds` and `targets` must be array-like') - n_chans = [] for inds in [*seeds, *targets]: if not isinstance(inds, array_like): raise TypeError( @@ -271,27 +296,15 @@ def seed_target_multivariate_indices(seeds, targets): if len(inds) != len(np.unique(inds)): raise ValueError( '`seeds` and `targets` cannot contain repeated channels') - n_chans.append(len(inds)) - max_n_chans = max(n_chans) - n_cons = len(seeds) * len(targets) - # pad indices to avoid ragged arrays - padded_seeds = np.full((len(seeds), max_n_chans), -1, dtype=np.int32) - padded_targets = np.full((len(targets), max_n_chans), -1, dtype=np.int32) - for con_i, seed in enumerate(seeds): - padded_seeds[con_i, :len(seed)] = seed - for con_i, target in enumerate(targets): - padded_targets[con_i, :len(target)] = target - - # create final indices - indices = (np.zeros((n_cons, max_n_chans), dtype=np.int32), - np.zeros((n_cons, max_n_chans), dtype=np.int32)) - con_i = 0 - for seed in padded_seeds: - for target in padded_targets: - indices[0][con_i] = seed - indices[1][con_i] = target - con_i += 1 + indices = [[], []] + for seed in seeds: + for target in targets: + indices[0].append(np.array(seed)) + indices[1].append(np.array(target)) + + indices = (np.array(indices[0], dtype=object), + np.array(indices[1], dtype=object)) return indices From afb10292d6d2ab37264033fb866120366bdc0879 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 26 Oct 2023 23:03:37 +0200 Subject: [PATCH 25/40] fix spelling error --- mne_connectivity/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 5b5f634b..46c3afc7 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -275,7 +275,7 @@ def seed_target_multivariate_indices(seeds, targets): np.array([[1, 2], [3, 4, 5]], dtype=object)) # targets Even if the number of channels does not vary, the indices will still be - stored as object arrays for compatability. + stored as object arrays for compatibility. More information on working with multivariate indices and handling connections where the number of seeds and targets are not equal can be From 776113b5e888bbbd0150efc940833b4804797796 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 27 Oct 2023 00:23:01 +0200 Subject: [PATCH 26/40] try fix codespell error --- ignore_words.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignore_words.txt b/ignore_words.txt index 6dc3af46..a03d67bd 100644 --- a/ignore_words.txt +++ b/ignore_words.txt @@ -2,4 +2,4 @@ nd adn ba BA -Manuel \ No newline at end of file +manuel \ No newline at end of file From aaa1119a67c9f442848a7f16f4d4323e190857a0 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 10:30:48 +0100 Subject: [PATCH 27/40] Revert "bug fix missing refactoring for example" This reverts commit 6fe682fc7d2ea88266641fd5b823088d122acc6c. --- examples/handling_ragged_arrays.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index ecb4a9c4..09c98059 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -14,7 +14,7 @@ import numpy as np -from mne_connectivity import spectral_connectivity_epochs_multivariate +from mne_connectivity import spectral_connectivity_epochs ############################################################################### # Background @@ -44,7 +44,7 @@ # targets = [[2, 3, 4], [4 ]] # # The ``indices`` parameter passed to -# :func:`~mne_connectivity.spectral_connectivity_epochs_multivariate` and +# :func:`~mne_connectivity.spectral_connectivity_epochs` and # :func:`~mne_connectivity.spectral_connectivity_time` must be a tuple of # array-likes, meaning # that the indices can be passed as a tuple of: lists; tuples; or NumPy arrays. @@ -105,7 +105,7 @@ [[2, 3, 4], [4]]) # targets # compute connectivity -con = spectral_connectivity_epochs_multivariate( +con = spectral_connectivity_epochs( data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30, verbose=False) patterns = np.array(con.attrs['patterns']) From fb501e2237009ec1006c7e69b5e3fb622a097efa Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 10:34:16 +0100 Subject: [PATCH 28/40] Revert "Squashed commit of the following:" This reverts commit 5bad5fb8913e66f2a46548523725afce3ff385a9. --- examples/handling_ragged_arrays.py | 39 +++--- mne_connectivity/io.py | 6 +- .../spectral/epochs_multivariate.py | 53 ++++---- .../spectral/tests/test_spectral.py | 6 +- mne_connectivity/spectral/time.py | 42 +++---- mne_connectivity/tests/test_utils.py | 16 +-- mne_connectivity/utils/utils.py | 113 ++++++++---------- 7 files changed, 127 insertions(+), 148 deletions(-) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index 09c98059..67f06a4e 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -62,8 +62,8 @@ # ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'), # np.array([[2, 3, 4], [4 ]], dtype='object')) # -# **N.B. Note that since NumPy v1.19.0, dtype='object' must be specified when -# forming ragged arrays.** +# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be +# specified.** # # Just as for bivariate connectivity, the length of ``indices[0]`` and # ``indices[1]`` is equal (i.e. the number of connections), however information @@ -71,16 +71,19 @@ # array. Importantly, these indices are ragged, as the first connection will be # computed between 2 seed and 3 target channels, and the second connection # between 4 seed and 1 target channel. The connectivity functions will -# recognise the indices as being ragged, and pad them to a 'full' array by -# adding placeholder values which are masked accordingly. This makes the -# indices easier to work with, and also compatible with the engine used to save -# connectivity objects. For example, the above indices would become:: +# recognise the indices as being ragged, and pad them accordingly to make them +# easier to work with and compatible with the h5netcdf saving engine. The known +# value used to pad the arrays is ``-1``, an invalid channel index. The above +# indices would be padded to:: # -# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]), -# np.array([[2, 3, 4, --], [4, --, --, --]])) +# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]), +# np.array([[2, 3, 4, -1], [4, -1, -1, -1]])) # -# where ``--`` are masked entries. These indices are what is stored in the -# returned connectivity objects. +# These indices are what is stored in the connectivity object, and is also the +# format of indices returned from the helper functions +# :func:`~mne_connectivity.check_multivariate_indices` and +# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also +# possible to pass the padded indices to the connectivity functions directly. # # For the connectivity results themselves, the methods available in # MNE-Connectivity combine information across the different channels into a @@ -115,11 +118,11 @@ max_n_chans = max( [len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])]) -# show that the padded indices entries are masked -assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels -assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels -assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels -assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels +# show that the padded indices entries are all -1 +assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels +assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels +assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels +assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels # patterns have shape [seeds/targets x cons x max channels x freqs (x times)] assert patterns.shape == (2, n_cons, max_n_chans, n_freqs) @@ -134,11 +137,11 @@ seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] -# extract patterns for second connection using the padded, masked indices +# extract patterns for second connection using the padded indices (pad = -1) seed_patterns_con2 = ( - patterns[0, 1, :padded_indices[0][1].count()]) + patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)]) target_patterns_con2 = ( - patterns[1, 1, :padded_indices[1][1].count()]) + patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)]) # show that shapes of patterns are correct assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) diff --git a/mne_connectivity/io.py b/mne_connectivity/io.py index e8d9b916..63aa3501 100644 --- a/mne_connectivity/io.py +++ b/mne_connectivity/io.py @@ -53,11 +53,9 @@ def _xarray_to_conn(array, cls_func): event_id = dict(zip(event_id_keys, event_id_vals)) array.attrs['event_id'] = event_id - # convert indices numpy arrays to a tuple of masked arrays - # (only multivariate connectivity indices saved as arrays) + # convert indices numpy arrays to a tuple of arrays if isinstance(array.attrs['indices'], np.ndarray): - array.attrs['indices'] = tuple( - np.ma.masked_values(array.attrs['indices'], -1)) + array.attrs['indices'] = tuple(array.attrs['indices']) # create the connectivity class conn = cls_func( diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index db648456..28077adb 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -31,18 +31,12 @@ def _check_indices(indices, method, n_signals): logger.info('using all indices for multivariate connectivity') indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], np.arange(n_signals, dtype=int)[np.newaxis, :]) - indices_use = np.ma.masked_array(indices_use, - mask=False, fill_value=-1) else: - indices_use = check_multivariate_indices(indices) # mask indices - indices_use = np.ma.concatenate([inds[np.newaxis] for inds in - indices_use]) - np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. + indices_use = check_multivariate_indices(indices) # pad with -1 if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices_use[0], indices_use[1]): - intersection = np.intersect1d(seed.compressed(), - target.compressed()) - if intersection.size > 0: + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') @@ -65,10 +59,16 @@ def _check_rank_input(rank, data, indices): else: data_arr = data + # XXX: Unpadding of arrays after already padding them is perhaps not so + # efficient. However, we need to remove the padded values to + # ensure only the correct channels are indexed, and having two + # versions of indices is a bit messy currently. A candidate for + # refactoring to simplify code. + for group_i in range(2): # seeds and targets for con_i, con_idcs in enumerate(indices[group_i]): - s = np.linalg.svd(data_arr[:, con_idcs.compressed()], - compute_uv=False) + con_idcs = con_idcs[con_idcs != -1] # -1 is padded value + s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) rank[group_i][con_i] = np.min( [np.count_nonzero(epoch >= epoch[0] * sv_tol) for epoch in s]) @@ -197,8 +197,8 @@ def compute_con(self, indices, ranks, n_epochs=1): indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - seed_idcs = seed_idcs.compressed() - target_idcs = target_idcs.compressed() + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] @@ -432,8 +432,8 @@ def compute_con(self, indices, ranks, n_epochs=1): indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - seed_idcs = seed_idcs.compressed() - target_idcs = target_idcs.compressed() + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] @@ -1009,8 +1009,7 @@ def spectral_connectivity_epochs_multivariate( # make sure padded indices are stored in the connectivity object if indices is not None: - # create a copy - indices = (indices_use[0].copy(), indices_use[1].copy()) + indices = tuple(np.array(indices_use)) # create a copy # get the window function, wavelets, etc for different modes (spectral_params, mt_adaptive, n_times_spectrum, @@ -1021,12 +1020,20 @@ def spectral_connectivity_epochs_multivariate( cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) # unique signals for which we actually need to compute CSD - sig_idx = np.unique(indices_use.compressed()) + sig_idx = np.unique(np.concatenate(np.concatenate( + indices_use))) + sig_idx = sig_idx[sig_idx != -1] remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} - remapped_inds = indices_use.copy() - for idx in sig_idx: - remapped_inds[indices_use == idx] = remapping[idx] - remapped_sig = np.unique(remapped_inds.compressed()) + remapping[-1] = -1 + remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + remapped_inds[0][con_i] = np.array([ + remapping[idx] for idx in seed]) + remapped_inds[1][con_i] = np.array([ + remapping[idx] for idx in target]) + con_i += 1 + remapped_sig = [remapping[idx] for idx in sig_idx] n_signals_use = len(sig_idx) # map indices to unique indices diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index d05ec61e..54bfafa5 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1392,11 +1392,7 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io( assert isinstance(con.indices, tuple) and isinstance( read_con.indices, tuple ) - # check indices are masked - assert all([np.ma.isMA(inds) for inds in con.indices] and - [np.ma.isMA(inds) for inds in read_con.indices]) # check indices have same values - assert np.all([con_inds == read_inds for con_inds, read_inds in - zip(con.indices, read_con.indices)]) + assert np.all(np.array(con.indices) == np.array(read_con.indices)) else: assert con.indices is None and read_con.indices is None diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index d44f5be3..d0059ace 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -408,50 +408,46 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], - np.arange(n_signals, dtype=int)[np.newaxis, :]) - indices_use = np.ma.masked_array(indices_use, - mask=False, fill_value=-1) + indices_use = (np.array([np.arange(n_signals, dtype=np.int32)]), + np.array([np.arange(n_signals, dtype=np.int32)])) else: logger.info('only using indices for lower-triangular matrix') indices_use = np.tril_indices(n_signals, k=-1) else: if multivariate_con: - indices_use = check_multivariate_indices(indices) # mask indices - indices_use = np.ma.concatenate([inds[np.newaxis] for inds in - indices_use]) - np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. + indices_use = check_multivariate_indices(indices) # pad with -1 if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices_use[0], indices_use[1]): - intersection = np.intersect1d(seed.compressed(), - target.compressed()) - if intersection.size > 0: + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') # make sure padded indices are stored in the connectivity object - # create a copy - indices = (indices_use[0].copy(), indices_use[1].copy()) + indices = tuple(np.array(indices_use)) # create a copy else: indices_use = check_indices(indices) - n_cons = len(indices_use[0]) + # create copies of indices_use for independent manipulation + source_idx = np.array(indices_use[0]) + target_idx = np.array(indices_use[1]) + n_cons = len(source_idx) # unique signals for which we actually need to compute the CSD of if multivariate_con: - signals_use = np.unique(indices_use.compressed()) + signals_use = np.unique(np.concatenate(np.concatenate(indices_use))) + signals_use = signals_use[signals_use != -1] remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(signals_use)} - remapped_inds = indices_use.copy() + remapping[-1] = -1 # multivariate functions expect seed/target remapping - for idx in signals_use: - remapped_inds[indices_use == idx] = remapping[idx] - source_idx = remapped_inds[0] - target_idx = remapped_inds[1] + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + source_idx[con_i] = np.array([remapping[idx] for idx in seed]) + target_idx[con_i] = np.array([remapping[idx] for idx in target]) + con_i += 1 max_n_channels = len(indices_use[0][0]) else: # no indices remapping required for bivariate functions signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) - source_idx = indices_use[0].copy() - target_idx = indices_use[1].copy() max_n_channels = len(indices_use[0]) # check rank input and compute data ranks if necessary diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index 31ed3c79..1e5822eb 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -34,22 +34,14 @@ def test_seed_target_indices(): seeds = [[0, 1]] targets = [[2, 3], [3, 4]] indices = seed_target_multivariate_indices(seeds, targets) - match_indices = (np.array([[0, 1], [0, 1]], dtype=object), - np.array([[2, 3], [3, 4]], dtype=object)) - for type_i in range(2): - for con_i in range(len(indices[0])): - assert np.all(indices[type_i][con_i] == - match_indices[type_i][con_i]) + assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), + np.array([[2, 3], [3, 4]]))) # ragged indices seeds = [[0, 1]] targets = [[2, 3, 4], [4]] indices = seed_target_multivariate_indices(seeds, targets) - match_indices = (np.array([[0, 1], [0, 1]], dtype=object), - np.array([[2, 3, 4], [4]], dtype=object)) - for type_i in range(2): - for con_i in range(len(indices[0])): - assert np.all(indices[type_i][con_i] == - match_indices[type_i][con_i]) + assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), + np.array([[2, 3, 4], [4, -1, -1]]))) # test error catching # non-array-like seeds/targets with pytest.raises(TypeError, diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 46c3afc7..16f56f09 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -83,23 +83,18 @@ def check_indices(indices): return indices -def check_multivariate_indices(indices, n_chans=None): - """Check indices parameter for multivariate connectivity and mask it. +def check_multivariate_indices(indices): + """Check indices parameter for multivariate connectivity and pad it. Parameters ---------- indices : tuple of array of array of int, shape (2, n_cons, variable) Tuple containing index sets. - n_chans : int | None (default None) - The number of channels in the data. Used when converting negative - indices to positive indices. Cannot be ``None`` if negative indices are - present. - Returns ------- indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans) - The indices as a masked array. + The indices padded with the invalid channel index ``-1``. Notes ----- @@ -113,35 +108,26 @@ def check_multivariate_indices(indices, n_chans=None): connection must be unique. If the seed and target indices are given as lists or tuples, they will be - converted to numpy arrays. Because the number of channels can differ across + converted to numpy arrays. In case the number of channels differs across connections or between the seeds and targets for a given connection (i.e. - ragged/jagged indices), the returned array will be padded out to a 'full' - array with an invalid index (``-1``) according to the maximum number of - channels in the seed or target of any one connection. These invalid - entries are then masked and returned as numpy masked arrays. E.g. the - ragged indices of shape ``(2, n_cons, variable)``:: + ragged indices), the returned array will be padded with the invalid channel + index ``-1`` according to the maximum number of channels in the seed or + target of any one connection. E.g. the ragged indices of shape ``(2, + n_cons, variable)``:: indices = ([[0, 1], [0, 1 ]], # seeds [[2, 3], [4, 5, 6]]) # targets - would be padded to full arrays:: - - indices = ([[0, 1, -1], [0, 1, -1]], # seeds - [[2, 3, -1], [4, 5, 6]]) # targets - - to have shape ``(2, n_cons, max_n_chans)``, where ``max_n_chans = 3``. The - invalid entries are then masked:: - - indices = ([[0, 1, --], [0, 1, --]], # seeds - [[2, 3, --], [4, 5, 6]]) # targets + would be returned as:: - In case "indices" contains negative values to index channels, these will be - converted to the corresponding positive-valued index before any masking is - applied. + indices = (np.array([[0, 1, -1], [0, 1, -1]]), # seeds + np.array([[2, 3, -1], [4, 5, -1]])) # targets - More information on working with multivariate indices and handling - connections where the number of seeds and targets are not equal can be - found in the :doc:`../auto_examples/handling_ragged_arrays` example. + where the indices have been padded with ``-1`` to have shape ``(2, n_cons, + max_n_chans)``, where ``max_n_chans = 3``. More information on working with + multivariate indices and handling connections where the number of seeds and + targets are not equal can be found in the + :doc:`../auto_examples/handling_ragged_arrays` example. """ if not isinstance(indices, tuple) or len(indices) != 2: raise ValueError('indices must be a tuple of length 2') @@ -151,7 +137,6 @@ def check_multivariate_indices(indices, n_chans=None): 'have the same length') n_cons = len(indices[0]) - invalid = -1 max_n_chans = 0 for inds in ([*indices[0], *indices[1]]): @@ -164,32 +149,17 @@ def check_multivariate_indices(indices, n_chans=None): 'multivariate indices cannot contain repeated channels within ' 'a seed or target') max_n_chans = max(max_n_chans, len(inds)) - # convert negative to positive indices - if any(idx < 0 for idx in inds): - if n_chans is None: - raise ValueError( - '`n_chans` must be given if there are negative values ' - 'in `indices`') - if any(idx * -1 > n_chans for idx in inds[inds < 0]): - raise ValueError( - 'a channel index is not present in the data' - ) - inds[inds < 0] = inds[inds < 0] % n_chans # pad indices to avoid ragged arrays - padded_indices = (np.full((n_cons, max_n_chans), invalid, dtype=np.int32), - np.full((n_cons, max_n_chans), invalid, dtype=np.int32)) + padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32), + np.full((n_cons, max_n_chans), -1, dtype=np.int32)) con_i = 0 for seed, target in zip(indices[0], indices[1]): padded_indices[0][con_i, :len(seed)] = seed padded_indices[1][con_i, :len(target)] = target con_i += 1 - # mask invalid indices - masked_indices = (np.ma.masked_values(padded_indices[0], invalid), - np.ma.masked_values(padded_indices[1], invalid)) - - return masked_indices + return padded_indices def seed_target_indices(seeds, targets): @@ -251,8 +221,8 @@ def seed_target_multivariate_indices(seeds, targets): Returns ------- - indices : tuple of array of array of int, shape (2, n_cons, variable) - The indices as a numpy object array. + indices : tuple of array of array of int, shape (2, n_cons, max_n_chans) + The indices padded with the invalid channel index ``-1``. Notes ----- @@ -262,8 +232,12 @@ def seed_target_multivariate_indices(seeds, targets): channels in the data. The length of indices for each connection do not need to be equal. Furthermore, all indices within a connection must be unique. - Because the number of channels per connection can vary, the indices are - stored as numpy arrays with ``dtype=object``. E.g. ``seeds`` and + ``seeds`` and ``targets`` will be expanded such that connectivity will be + computed between each set of seeds and targets. In case the number of + channels differs across connections or between the seeds and targets for a + given connection (i.e. ragged indices), the returned array will be padded + with the invalid channel index ``-1`` according to the maximum number of + channels in the seed or target of any one connection. E.g. ``seeds`` and ``targets``:: seeds = [[0]] @@ -271,8 +245,8 @@ def seed_target_multivariate_indices(seeds, targets): would be returned as:: - indices = (np.array([[0 ], [0 ]], dtype=object), # seeds - np.array([[1, 2], [3, 4, 5]], dtype=object)) # targets + indices = (np.array([[0 ], [0 ]]), # seeds + np.array([[1, 2], [3, 4, 5]])) # targets Even if the number of channels does not vary, the indices will still be stored as object arrays for compatibility. @@ -289,6 +263,7 @@ def seed_target_multivariate_indices(seeds, targets): ): raise TypeError('`seeds` and `targets` must be array-like') + n_chans = [] for inds in [*seeds, *targets]: if not isinstance(inds, array_like): raise TypeError( @@ -296,15 +271,27 @@ def seed_target_multivariate_indices(seeds, targets): if len(inds) != len(np.unique(inds)): raise ValueError( '`seeds` and `targets` cannot contain repeated channels') + n_chans.append(len(inds)) + max_n_chans = max(n_chans) + n_cons = len(seeds) * len(targets) - indices = [[], []] - for seed in seeds: - for target in targets: - indices[0].append(np.array(seed)) - indices[1].append(np.array(target)) - - indices = (np.array(indices[0], dtype=object), - np.array(indices[1], dtype=object)) + # pad indices to avoid ragged arrays + padded_seeds = np.full((len(seeds), max_n_chans), -1, dtype=np.int32) + padded_targets = np.full((len(targets), max_n_chans), -1, dtype=np.int32) + for con_i, seed in enumerate(seeds): + padded_seeds[con_i, :len(seed)] = seed + for con_i, target in enumerate(targets): + padded_targets[con_i, :len(target)] = target + + # create final indices + indices = (np.zeros((n_cons, max_n_chans), dtype=np.int32), + np.zeros((n_cons, max_n_chans), dtype=np.int32)) + con_i = 0 + for seed in padded_seeds: + for target in padded_targets: + indices[0][con_i] = seed + indices[1][con_i] = target + con_i += 1 return indices From e27bf6144ccbd71cdd52e672bb831b0881b6c6e5 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 10:35:12 +0100 Subject: [PATCH 29/40] Revert "Squashed commit of the following:" This reverts commit 5bad5fb8913e66f2a46548523725afce3ff385a9. --- doc/api.rst | 1 - examples/granger_causality.py | 14 +- examples/mic_mim.py | 10 +- mne_connectivity/__init__.py | 4 +- mne_connectivity/spectral/__init__.py | 3 +- mne_connectivity/spectral/epochs.py | 1693 ++++++++++++++++- mne_connectivity/spectral/epochs_bivariate.py | 729 ------- .../spectral/epochs_multivariate.py | 1129 ----------- .../spectral/tests/test_spectral.py | 158 +- mne_connectivity/spectral/time.py | 5 +- 10 files changed, 1693 insertions(+), 2053 deletions(-) delete mode 100644 mne_connectivity/spectral/epochs_bivariate.py delete mode 100644 mne_connectivity/spectral/epochs_multivariate.py diff --git a/doc/api.rst b/doc/api.rst index 3fe85832..c91f9c02 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -47,7 +47,6 @@ on numpy array inputs. phase_slope_index vector_auto_regression spectral_connectivity_epochs - spectral_connectivity_epochs_multivariate spectral_connectivity_time Reading functions diff --git a/examples/granger_causality.py b/examples/granger_causality.py index 4129dadc..64a657db 100644 --- a/examples/granger_causality.py +++ b/examples/granger_causality.py @@ -20,7 +20,7 @@ import mne from mne.datasets.fieldtrip_cmc import data_path -from mne_connectivity import spectral_connectivity_epochs_multivariate +from mne_connectivity import spectral_connectivity_epochs ############################################################################### # Background @@ -161,10 +161,10 @@ indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A # compute Granger causality -gc_ab = spectral_connectivity_epochs_multivariate( +gc_ab = spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # A => B -gc_ba = spectral_connectivity_epochs_multivariate( +gc_ba = spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ba, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # B => A freqs = gc_ab.freqs @@ -262,10 +262,10 @@ # %% # compute GC on time-reversed signals -gc_tr_ab = spectral_connectivity_epochs_multivariate( +gc_tr_ab = spectral_connectivity_epochs( epochs, method=['gc_tr'], indices=indices_ab, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # TR[A => B] -gc_tr_ba = spectral_connectivity_epochs_multivariate( +gc_tr_ba = spectral_connectivity_epochs( epochs, method=['gc_tr'], indices=indices_ba, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # TR[B => A] @@ -317,7 +317,7 @@ # %% -gc_ab_60 = spectral_connectivity_epochs_multivariate( +gc_ab_60 = spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=60) # A => B @@ -375,7 +375,7 @@ # %% try: - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None, gc_n_lags=20, verbose=False) # A => B print('Success!') diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 62674e75..87111586 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -25,9 +25,7 @@ import mne from mne import EvokedArray, make_fixed_length_epochs from mne.datasets.fieldtrip_cmc import data_path -from mne_connectivity import (seed_target_indices, - spectral_connectivity_epochs, - spectral_connectivity_epochs_multivariate) +from mne_connectivity import seed_target_indices, spectral_connectivity_epochs ############################################################################### # Background @@ -89,7 +87,7 @@ target_names = [epochs.info['ch_names'][idx] for idx in targets] # multivariate imaginary part of coherency -(mic, mim) = spectral_connectivity_epochs_multivariate( +(mic, mim) = spectral_connectivity_epochs( epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, rank=None) @@ -292,7 +290,7 @@ # %% indices = (np.array([[*seeds, *targets]]), np.array([[*seeds, *targets]])) -gim = spectral_connectivity_epochs_multivariate( +gim = spectral_connectivity_epochs( epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None, verbose=False) @@ -344,7 +342,7 @@ # %% -(mic_red, mim_red) = spectral_connectivity_epochs_multivariate( +(mic_red, mim_red) = spectral_connectivity_epochs( epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, rank=([25], [25])) diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index 32488b33..c2f03a6c 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -15,9 +15,7 @@ from .effective import phase_slope_index from .envelope import envelope_correlation, symmetric_orth from .io import read_connectivity -from .spectral import (spectral_connectivity_time, - spectral_connectivity_epochs, - spectral_connectivity_epochs_multivariate) +from .spectral import spectral_connectivity_time, spectral_connectivity_epochs from .vector_ar import vector_auto_regression, select_order from .utils import (check_indices, check_multivariate_indices, degree, seed_target_indices, seed_target_multivariate_indices) diff --git a/mne_connectivity/spectral/__init__.py b/mne_connectivity/spectral/__init__.py index f2252db9..a0f06ef6 100644 --- a/mne_connectivity/spectral/__init__.py +++ b/mne_connectivity/spectral/__init__.py @@ -1,3 +1,2 @@ -from .epochs_bivariate import spectral_connectivity_epochs -from .epochs_multivariate import spectral_connectivity_epochs_multivariate +from .epochs import spectral_connectivity_epochs from .time import spectral_connectivity_time \ No newline at end of file diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 5742ae33..7ad551c9 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -2,6 +2,8 @@ # Denis A. Engemann # Adam Li # Thomas S. Binns +# Tien D. Nguyen +# Richard M. Köhler # # License: BSD (3-clause) @@ -9,16 +11,20 @@ import inspect import numpy as np +import scipy as sp from mne.epochs import BaseEpochs from mne.parallel import parallel_func from mne.source_estimate import _BaseSourceEstimate -from mne.time_frequency.multitaper import ( - _csd_from_mt, _mt_spectra, _psd_from_mt, _psd_from_mt_adaptive) +from mne.time_frequency.multitaper import (_csd_from_mt, + _mt_spectra, _psd_from_mt, + _psd_from_mt_adaptive) from mne.time_frequency.tfr import cwt, morlet from mne.time_frequency.multitaper import _compute_mt_params -from mne.utils import _arange_div, _check_option, _time_mask, logger, warn +from mne.utils import ( + ProgressBar, _arange_div, _check_option, _time_mask, logger, warn, verbose) -from ..base import SpectralConnectivity, SpectroTemporalConnectivity +from ..base import (SpectralConnectivity, SpectroTemporalConnectivity) +from ..utils import fill_doc, check_indices, check_multivariate_indices def _compute_freqs(n_times, sfreq, cwt_freqs, mode): @@ -57,8 +63,10 @@ def _compute_freq_mask(freqs_all, fmin, fmax, fskip): return freq_mask -def _prepare_connectivity(epoch_block, times_in, tmin, tmax, fmin, fmax, sfreq, - mode, fskip, n_bands, cwt_freqs, faverage): +def _prepare_connectivity(epoch_block, times_in, tmin, tmax, + fmin, fmax, sfreq, indices, + method, mode, fskip, n_bands, + cwt_freqs, faverage): """Check and precompute dimensions of results data.""" first_epoch = epoch_block[0] @@ -84,6 +92,43 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, fmin, fmax, sfreq, times = times_in[tmin_idx:tmax_idx] n_times = len(times) + if any(this_method in _multivariate_methods for this_method in method): + multivariate_con = True + else: + multivariate_con = False + + if indices is None: + if multivariate_con: + if any(this_method in _gc_methods for this_method in method): + raise ValueError( + 'indices must be specified when computing Granger ' + 'causality, as all-to-all connectivity is not supported') + else: + logger.info('using all indices for multivariate connectivity') + indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) + else: + logger.info('only using indices for lower-triangular matrix') + # only compute r for lower-triangular region + indices_use = np.tril_indices(n_signals, -1) + else: + if multivariate_con: + indices_use = check_multivariate_indices(indices) # pad with -1 + if any(this_method in _gc_methods for this_method in method): + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries + raise ValueError( + 'seed and target indices must not intersect when ' + 'computing Granger causality') + else: + indices_use = check_indices(indices) + + # number of connectivities to compute + n_cons = len(indices_use[0]) + + logger.info(' computing connectivity for %d connections' + % n_cons) logger.info(' using t=%0.3fs..%0.3fs for estimation (%d points)' % (tmin_true, tmax_true, n_times)) @@ -139,9 +184,55 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, fmin, fmax, sfreq, logger.info(' connectivity scores will be averaged for ' 'each band') - return (times, n_times, times_in, n_times_in, tmin_idx, + return (n_cons, times, n_times, times_in, n_times_in, tmin_idx, tmax_idx, n_freqs, freq_mask, freqs, freqs_bands, freq_idx_bands, - n_signals, warn_times) + n_signals, indices_use, warn_times) + + +def _check_rank_input(rank, data, indices): + """Check the rank argument is appropriate and compute rank if missing.""" + sv_tol = 1e-10 # tolerance for non-zero singular val (rel to largest) + if rank is None: + rank = np.zeros((2, len(indices[0])), dtype=int) + + if isinstance(data, BaseEpochs): + data_arr = data.get_data() + else: + data_arr = data + + # XXX: Unpadding of arrays after already padding them is perhaps not so + # efficient. However, we need to remove the padded values to + # ensure only the correct channels are indexed, and having two + # versions of indices is a bit messy currently. A candidate for + # refactoring to simplify code. + + for group_i in range(2): # seeds and targets + for con_i, con_idcs in enumerate(indices[group_i]): + con_idcs = con_idcs[con_idcs != -1] # -1 is padded value + s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) + rank[group_i][con_i] = np.min( + [np.count_nonzero(epoch >= epoch[0] * sv_tol) + for epoch in s]) + + logger.info('Estimated data ranks:') + con_i = 1 + for seed_rank, target_rank in zip(rank[0], rank[1]): + logger.info(' connection %i - seeds (%i); targets (%i)' + % (con_i, seed_rank, target_rank, )) + con_i += 1 + rank = tuple((np.array(rank[0]), np.array(rank[1]))) + + else: + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], rank[0], rank[1]): + if not (0 < seed_rank <= len(seed_idcs) and + 0 < target_rank <= len(target_idcs)): + raise ValueError( + 'ranks for seeds and targets must be > 0 and <= the ' + 'number of channels in the seeds and targets, ' + 'respectively, for each connection') + + return rank def _assemble_spectral_params(mode, n_times, mt_adaptive, mt_bandwidth, sfreq, @@ -183,46 +274,6 @@ def _assemble_spectral_params(mode, n_times, mt_adaptive, mt_bandwidth, sfreq, return spectral_params, mt_adaptive, n_times_spectrum, n_tapers -def _compute_spectral_methods_epochs( - con_methods, epoch_block, epoch_idx, call_params, parallel, - my_spectral_connectivity_epochs, n_jobs, n_times_in, times_in, - warn_times -): - """Compute CSD/PSD for spectral_connectivity_epochs... functions.""" - # check dimensions and time scale - for this_epoch in epoch_block: - _, _, _, warn_times = _get_and_verify_data_sizes( - this_epoch, call_params["sfreq"], call_params["n_signals"], - n_times_in, times_in, warn_times=warn_times) - - if n_jobs == 1: - # no parallel processing - for this_epoch in epoch_block: - logger.info(' computing cross-spectral density for epoch %d' - % (epoch_idx + 1)) - # con methods and psd are updated inplace - _epoch_spectral_connectivity(data=this_epoch, **call_params) - epoch_idx += 1 - else: - # process epochs in parallel - logger.info( - ' computing cross-spectral density for epochs %d..%d' - % (epoch_idx + 1, epoch_idx + len(epoch_block))) - - out = parallel(my_spectral_connectivity_epochs( - data=this_epoch, **call_params) - for this_epoch in epoch_block) - # do the accumulation - for this_out in out: - for _method, parallel_method in zip(con_methods, this_out[0]): - _method.combine(parallel_method) - if call_params["psd"] is not None: - call_params["psd"] += this_out[1] - - epoch_idx += len(epoch_block) - - return epoch_idx - ######################################################################## # Various connectivity estimators @@ -242,9 +293,996 @@ def combine(self, other): def compute_con(self, con_idx, n_epochs): raise NotImplementedError('compute_con method not implemented') + +class _EpochMeanConEstBase(_AbstractConEstBase): + """Base class for methods that estimate connectivity as mean epoch-wise.""" + + patterns = None + + def __init__(self, n_cons, n_freqs, n_times): + self.n_cons = n_cons + self.n_freqs = n_freqs + self.n_times = n_times + + if n_times == 0: + self.csd_shape = (n_cons, n_freqs) + else: + self.csd_shape = (n_cons, n_freqs, n_times) + + self.con_scores = None + + def start_epoch(self): # noqa: D401 + """Called at the start of each epoch.""" + pass # for this type of con. method we don't do anything + + def combine(self, other): + """Include con. accumated for some epochs in this estimate.""" + self._acc += other._acc + + +class _EpochMeanMultivariateConEstBase(_AbstractConEstBase): + """Base class for mean epoch-wise multivar. con. estimation methods.""" + + n_steps = None + patterns = None + + def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): + self.n_signals = n_signals + self.n_cons = n_cons + self.n_freqs = n_freqs + self.n_times = n_times + self.n_jobs = n_jobs + + # include time dimension, even when unused for indexing flexibility + if n_times == 0: + self.csd_shape = (n_signals**2, n_freqs) + self.con_scores = np.zeros((n_cons, n_freqs, 1)) + else: + self.csd_shape = (n_signals**2, n_freqs, n_times) + self.con_scores = np.zeros((n_cons, n_freqs, n_times)) + + # allocate space for accumulation of CSD + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + self._compute_n_progress_bar_steps() + + def start_epoch(self): # noqa: D401 + """Called at the start of each epoch.""" + pass # for this type of con. method we don't do anything + + def combine(self, other): + """Include con. accumulated for some epochs in this estimate.""" + self._acc += other._acc + + def accumulate(self, con_idx, csd_xy): + """Accumulate CSD for some connections.""" + self._acc[con_idx] += csd_xy + + def _compute_n_progress_bar_steps(self): + """Calculate the number of steps to include in the progress bar.""" + self.n_steps = int(np.ceil(self.n_freqs / self.n_jobs)) + + def _log_connection_number(self, con_i): + """Log the number of the connection being computed.""" + logger.info('Computing %s for connection %i of %i' + % (self.name, con_i + 1, self.n_cons, )) + + def _get_block_indices(self, block_i, limit): + """Get indices for a computation block capped by a limit.""" + indices = np.arange(block_i * self.n_jobs, (block_i + 1) * self.n_jobs) + + return indices[np.nonzero(indices < limit)] + + def reshape_csd(self): + """Reshape CSD into a matrix of times x freqs x signals x signals.""" + if self.n_times == 0: + return (np.reshape(self._acc, ( + self.n_signals, self.n_signals, self.n_freqs, 1) + ).transpose(3, 2, 0, 1)) + + return (np.reshape(self._acc, ( + self.n_signals, self.n_signals, self.n_freqs, self.n_times) + ).transpose(3, 2, 0, 1)) + + +class _CohEstBase(_EpochMeanConEstBase): + """Base Estimator for Coherence, Coherency, Imag. Coherence.""" + + accumulate_psd = True + + def __init__(self, n_cons, n_freqs, n_times): + super(_CohEstBase, self).__init__(n_cons, n_freqs, n_times) + + # allocate space for accumulation of CSD + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate CSD for some connections.""" + self._acc[con_idx] += csd_xy + + +class _CohEst(_CohEstBase): + """Coherence Estimator.""" + + name = 'Coherence' + + def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + csd_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy) + + +class _CohyEst(_CohEstBase): + """Coherency Estimator.""" + + name = 'Coherency' + + def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape, + dtype=np.complex128) + csd_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = csd_mean / np.sqrt(psd_xx * psd_yy) + + +class _ImCohEst(_CohEstBase): + """Imaginary Coherence Estimator.""" + + name = 'Imaginary Coherence' + + def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + csd_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = np.imag(csd_mean) / np.sqrt(psd_xx * psd_yy) + + +class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): + """Base estimator for multivariate imag. part of coherency methods. + + See Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 + for equation references. + """ + + name = None + accumulate_psd = False + + def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): + super(_MultivariateCohEstBase, self).__init__( + n_signals, n_cons, n_freqs, n_times, n_jobs) + + def compute_con(self, indices, ranks, n_epochs=1): + """Compute multivariate imag. part of coherency between signals.""" + assert self.name in ['MIC', 'MIM'], ( + 'the class name is not recognised, please contact the ' + 'mne-connectivity developers') + + csd = self.reshape_csd() / n_epochs + n_times = csd.shape[0] + times = np.arange(n_times) + freqs = np.arange(self.n_freqs) + + if self.name == 'MIC': + self.patterns = np.full( + (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), + np.nan) + + con_i = 0 + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], ranks[0], ranks[1]): + self._log_connection_number(con_i) + + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] + con_idcs = [*seed_idcs, *target_idcs] + + C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] + + # Eqs. 32 & 33 + C_bar, U_bar_aa, U_bar_bb = self._csd_svd( + C, seed_idcs, seed_rank, target_rank) + + # Eqs. 3 & 4 + E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) + + if self.name == 'MIC': + self._compute_mic(E, C, seed_idcs, target_idcs, n_times, + U_bar_aa, U_bar_bb, con_i) + else: + self._compute_mim(E, seed_idcs, target_idcs, con_i) + + con_i += 1 + + self.reshape_results() + + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): + """Dimensionality reduction of CSD with SVD.""" + n_times = csd.shape[0] + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds + + C_aa = csd[..., :n_seeds, :n_seeds] + C_ab = csd[..., :n_seeds, n_seeds:] + C_bb = csd[..., n_seeds:, n_seeds:] + C_ba = csd[..., n_seeds:, :n_seeds] + + # Eq. 32 + if seed_rank != n_seeds: + U_aa = np.linalg.svd(np.real(C_aa), full_matrices=False)[0] + U_bar_aa = U_aa[..., :seed_rank] + else: + U_bar_aa = np.broadcast_to( + np.identity(n_seeds), + (n_times, self.n_freqs) + (n_seeds, n_seeds)) + + if target_rank != n_targets: + U_bb = np.linalg.svd(np.real(C_bb), full_matrices=False)[0] + U_bar_bb = U_bb[..., :target_rank] + else: + U_bar_bb = np.broadcast_to( + np.identity(n_targets), + (n_times, self.n_freqs) + (n_targets, n_targets)) + + # Eq. 33 + C_bar_aa = np.matmul( + U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) + C_bar_ab = np.matmul( + U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb)) + C_bar_bb = np.matmul( + U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb)) + C_bar_ba = np.matmul( + U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa)) + C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), + np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) + + return C_bar, U_bar_aa, U_bar_bb + + def _compute_e(self, csd, n_seeds): + """Compute E from the CSD.""" + C_r = np.real(csd) + + parallel, parallel_compute_t, _ = parallel_func( + _mic_mim_compute_t, self.n_jobs, verbose=False) + + # imag. part of T filled when data is rank-deficient + T = np.zeros(csd.shape, dtype=np.complex128) + for block_i in ProgressBar( + range(self.n_steps), mesg="frequency blocks"): + freqs = self._get_block_indices(block_i, self.n_freqs) + T[:, freqs] = np.array(parallel(parallel_compute_t( + C_r[:, f], T[:, f], n_seeds) for f in freqs) + ).transpose(1, 0, 2, 3) + + if not np.isreal(T).all() or not np.isfinite(T).all(): + raise RuntimeError( + 'the transformation matrix of the data must be real-valued ' + 'and contain no NaN or infinity values; check that you are ' + 'using full rank data or specify an appropriate rank for the ' + 'seeds and targets that is less than or equal to their ranks') + T = np.real(T) # make T real if check passes + + # Eq. 4 + D = np.matmul(T, np.matmul(csd, T)) + + # E as imag. part of D between seeds and targets + return np.imag(D[..., :n_seeds, n_seeds:]) + + def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, + U_bar_bb, con_i): + """Compute MIC and the associated spatial patterns.""" + n_seeds = len(seed_idcs) + n_targets = len(target_idcs) + times = np.arange(n_times) + freqs = np.arange(self.n_freqs) + + # Eigendecomp. to find spatial filters for seeds and targets + w_seeds, V_seeds = np.linalg.eigh( + np.matmul(E, E.transpose(0, 1, 3, 2))) + w_targets, V_targets = np.linalg.eigh( + np.matmul(E.transpose(0, 1, 3, 2), E)) + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + ): + # strange edge-case where the eigenvectors returned should be a set + # of identity matrices with one rotated by 90 degrees, but are + # instead identical (i.e. are not rotated versions of one another). + # This leads to the case where the spatial filters are incorrectly + # applied, resulting in connectivity estimates of ~0 when they + # should be perfectly correlated ~1. Accordingly, we manually + # create a set of rotated identity matrices to use as the filters. + create_filter = False + stop = False + while not create_filter and not stop: + for time_i in range(n_times): + for freq_i in range(self.n_freqs): + if np.all(V_seeds[time_i, freq_i] == + V_targets[time_i, freq_i]): + create_filter = True + break + stop = True + if create_filter: + n_chans = E.shape[2] + eye_4d = np.zeros_like(V_seeds) + eye_4d[:, :, np.arange(n_chans), np.arange(n_chans)] = 1 + V_seeds = eye_4d + V_targets = np.rot90(eye_4d, axes=(2, 3)) + + # Spatial filters with largest eigval. for seeds and targets + alpha = V_seeds[times[:, None], freqs, :, w_seeds.argmax(axis=2)] + beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] + + # Eq. 46 (seed spatial patterns) + self.patterns[0, con_i, :n_seeds] = (np.matmul( + np.real(C[..., :n_seeds, :n_seeds]), + np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T + + # Eq. 47 (target spatial patterns) + self.patterns[1, con_i, :n_targets] = (np.matmul( + np.real(C[..., n_seeds:, n_seeds:]), + np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T + + # Eq. 7 + self.con_scores[con_i] = (np.einsum( + 'ijk,ijk->ij', alpha, np.matmul(E, np.expand_dims( + beta, axis=3))[..., 0] + ) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2)).T + + def _compute_mim(self, E, seed_idcs, target_idcs, con_i): + """Compute MIM (a.k.a. GIM if seeds == targets).""" + # Eq. 14 + self.con_scores[con_i] = np.matmul( + E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T + + # Eq. 15 + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + ): + self.con_scores[con_i] *= 0.5 + + def reshape_results(self): + """Remove time dimension from results, if necessary.""" + if self.n_times == 0: + self.con_scores = self.con_scores[..., 0] + if self.patterns is not None: + self.patterns = self.patterns[..., 0] + + +def _mic_mim_compute_t(C, T, n_seeds): + """Compute T for a single frequency (used for MIC and MIM).""" + for time_i in range(C.shape[0]): + T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( + C[time_i, :n_seeds, :n_seeds], -0.5 + ) + T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power( + C[time_i, n_seeds:, n_seeds:], -0.5 + ) + + return T + + +class _MICEst(_MultivariateCohEstBase): + """Multivariate imaginary part of coherency (MIC) estimator.""" + + name = "MIC" + + +class _MIMEst(_MultivariateCohEstBase): + """Multivariate interaction measure (MIM) estimator.""" + + name = "MIM" + + +class _PLVEst(_EpochMeanConEstBase): + """PLV Estimator.""" + + name = 'PLV' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_PLVEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += csd_xy / np.abs(csd_xy) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + plv = np.abs(self._acc / n_epochs) + self.con_scores[con_idx] = plv + + +class _ciPLVEst(_EpochMeanConEstBase): + """corrected imaginary PLV Estimator.""" + + name = 'ciPLV' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_ciPLVEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += csd_xy / np.abs(csd_xy) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + imag_plv = np.abs(np.imag(self._acc)) / n_epochs + real_plv = np.real(self._acc) / n_epochs + real_plv = np.clip(real_plv, -1, 1) # bounded from -1 to 1 + mask = (np.abs(real_plv) == 1) # avoid division by 0 + real_plv[mask] = 0 + corrected_imag_plv = imag_plv / np.sqrt(1 - real_plv ** 2) + self.con_scores[con_idx] = corrected_imag_plv + + +class _PLIEst(_EpochMeanConEstBase): + """PLI Estimator.""" + + name = 'PLI' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_PLIEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += np.sign(np.imag(csd_xy)) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + pli_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = np.abs(pli_mean) + + +class _PLIUnbiasedEst(_PLIEst): + """Unbiased PLI Square Estimator.""" + + name = 'Unbiased PLI Square' + accumulate_psd = False + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + pli_mean = self._acc[con_idx] / n_epochs + + # See Vinck paper Eq. (30) + con = (n_epochs * pli_mean ** 2 - 1) / (n_epochs - 1) + + self.con_scores[con_idx] = con + + +class _DPLIEst(_EpochMeanConEstBase): + """DPLI Estimator.""" + + name = 'DPLI' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_DPLIEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += np.heaviside(np.imag(csd_xy), 0.5) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + con = self._acc[con_idx] / n_epochs + + self.con_scores[con_idx] = con + + +class _WPLIEst(_EpochMeanConEstBase): + """WPLI Estimator.""" + + name = 'WPLI' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_WPLIEst, self).__init__(n_cons, n_freqs, n_times) + + # store both imag(csd) and abs(imag(csd)) + acc_shape = (2,) + self.csd_shape + self._acc = np.zeros(acc_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + im_csd = np.imag(csd_xy) + self._acc[0, con_idx] += im_csd + self._acc[1, con_idx] += np.abs(im_csd) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + num = np.abs(self._acc[0, con_idx]) + denom = self._acc[1, con_idx] + + # handle zeros in denominator + z_denom = np.where(denom == 0.) + denom[z_denom] = 1. + + con = num / denom + + # where we had zeros in denominator, we set con to zero + con[z_denom] = 0. + + self.con_scores[con_idx] = con + + +class _WPLIDebiasedEst(_EpochMeanConEstBase): + """Debiased WPLI Square Estimator.""" + + name = 'Debiased WPLI Square' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_WPLIDebiasedEst, self).__init__(n_cons, n_freqs, n_times) + # store imag(csd), abs(imag(csd)), imag(csd)^2 + acc_shape = (3,) + self.csd_shape + self._acc = np.zeros(acc_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + im_csd = np.imag(csd_xy) + self._acc[0, con_idx] += im_csd + self._acc[1, con_idx] += np.abs(im_csd) + self._acc[2, con_idx] += im_csd ** 2 + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + # note: we use the trick from fieldtrip to compute the + # the estimate over all pairwise epoch combinations + sum_im_csd = self._acc[0, con_idx] + sum_abs_im_csd = self._acc[1, con_idx] + sum_sq_im_csd = self._acc[2, con_idx] + + denom = sum_abs_im_csd ** 2 - sum_sq_im_csd + + # handle zeros in denominator + z_denom = np.where(denom == 0.) + denom[z_denom] = 1. + + con = (sum_im_csd ** 2 - sum_sq_im_csd) / denom + + # where we had zeros in denominator, we set con to zero + con[z_denom] = 0. + + self.con_scores[con_idx] = con + + +class _PPCEst(_EpochMeanConEstBase): + """Pairwise Phase Consistency (PPC) Estimator.""" + + name = 'PPC' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_PPCEst, self).__init__(n_cons, n_freqs, n_times) + + # store csd / abs(csd) + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + denom = np.abs(csd_xy) + z_denom = np.where(denom == 0.) + denom[z_denom] = 1. + this_acc = csd_xy / denom + this_acc[z_denom] = 0. # handle division by zero + + self._acc[con_idx] += this_acc + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + # note: we use the trick from fieldtrip to compute the + # the estimate over all pairwise epoch combinations + con = ((self._acc[con_idx] * np.conj(self._acc[con_idx]) - n_epochs) / + (n_epochs * (n_epochs - 1.))) + + self.con_scores[con_idx] = np.real(con) + + +class _GCEstBase(_EpochMeanMultivariateConEstBase): + """Base multivariate state-space Granger causality estimator.""" + + accumulate_psd = False + + def __init__(self, n_signals, n_cons, n_freqs, n_times, n_lags, n_jobs=1): + super(_GCEstBase, self).__init__( + n_signals, n_cons, n_freqs, n_times, n_jobs) + + self.freq_res = (self.n_freqs - 1) * 2 + if n_lags >= self.freq_res: + raise ValueError( + 'the number of lags (%i) must be less than double the ' + 'frequency resolution (%i)' % (n_lags, self.freq_res, )) + self.n_lags = n_lags + + def compute_con(self, indices, ranks, n_epochs=1): + """Compute multivariate state-space Granger causality.""" + assert self.name in ['GC', 'GC time-reversed'], ( + 'the class name is not recognised, please contact the ' + 'mne-connectivity developers') + + csd = self.reshape_csd() / n_epochs + + n_times = csd.shape[0] + times = np.arange(n_times) + freqs = np.arange(self.n_freqs) + + con_i = 0 + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], ranks[0], ranks[1]): + self._log_connection_number(con_i) + + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] + con_idcs = [*seed_idcs, *target_idcs] + + C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] + + C_bar = self._csd_svd(C, seed_idcs, seed_rank, target_rank) + n_signals = seed_rank + target_rank + con_seeds = np.arange(seed_rank) + con_targets = np.arange(target_rank) + seed_rank + + autocov = self._compute_autocov(C_bar) + if self.name == "GC time-reversed": + autocov = autocov.transpose(0, 1, 3, 2) + + A_f, V = self._autocov_to_full_var(autocov) + A_f_3d = np.reshape( + A_f, (n_times, n_signals, n_signals * self.n_lags), + order="F") + A, K = self._full_var_to_iss(A_f_3d) + + self.con_scores[con_i] = self._iss_to_ugc( + A, A_f_3d, K, V, con_seeds, con_targets) + + con_i += 1 + + self.reshape_results() + + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): + """Dimensionality reduction of CSD with SVD on the covariance.""" + # sum over times and epochs to get cov. from CSD + cov = csd.sum(axis=(0, 1)) + + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds + + cov_aa = cov[:n_seeds, :n_seeds] + cov_bb = cov[n_seeds:, n_seeds:] + + if seed_rank != n_seeds: + U_aa = np.linalg.svd(np.real(cov_aa), full_matrices=False)[0] + U_bar_aa = U_aa[:, :seed_rank] + else: + U_bar_aa = np.identity(n_seeds) + + if target_rank != n_targets: + U_bb = np.linalg.svd(np.real(cov_bb), full_matrices=False)[0] + U_bar_bb = U_bb[:, :target_rank] + else: + U_bar_bb = np.identity(n_targets) + + C_aa = csd[..., :n_seeds, :n_seeds] + C_ab = csd[..., :n_seeds, n_seeds:] + C_bb = csd[..., n_seeds:, n_seeds:] + C_ba = csd[..., n_seeds:, :n_seeds] + + C_bar_aa = np.matmul( + U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa)) + C_bar_ab = np.matmul( + U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb)) + C_bar_bb = np.matmul( + U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb)) + C_bar_ba = np.matmul( + U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa)) + C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), + np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) + + return C_bar + + def _compute_autocov(self, csd): + """Compute autocovariance from the CSD.""" + n_times = csd.shape[0] + n_signals = csd.shape[2] + + circular_shifted_csd = np.concatenate( + [np.flip(np.conj(csd[:, 1:]), axis=1), csd[:, :-1]], axis=1) + ifft_shifted_csd = self._block_ifft( + circular_shifted_csd, self.freq_res) + lags_ifft_shifted_csd = np.reshape( + ifft_shifted_csd[:, :self.n_lags + 1], + (n_times, self.n_lags + 1, n_signals ** 2), order="F") + + signs = np.repeat([1], self.n_lags + 1).tolist() + signs[1::2] = [x * -1 for x in signs[1::2]] + sign_matrix = np.repeat( + np.tile(np.array(signs), (n_signals ** 2, 1))[np.newaxis], + n_times, axis=0).transpose(0, 2, 1) + + return np.real(np.reshape( + sign_matrix * lags_ifft_shifted_csd, + (n_times, self.n_lags + 1, n_signals, n_signals), order="F")) + + def _block_ifft(self, csd, n_points): + """Compute block iFFT with n points.""" + shape = csd.shape + csd_3d = np.reshape( + csd, (shape[0], shape[1], shape[2] * shape[3]), order="F") + + csd_ifft = np.fft.ifft(csd_3d, n=n_points, axis=1) + + return np.reshape(csd_ifft, shape, order="F") + + def _autocov_to_full_var(self, autocov): + """Compute full VAR model using Whittle's LWR recursion.""" + if np.any(np.linalg.det(autocov) == 0): + raise RuntimeError( + 'the autocovariance matrix is singular; check if your data is ' + 'rank deficient and specify an appropriate rank argument <= ' + 'the rank of the seeds and targets') + + A_f, V = self._whittle_lwr_recursion(autocov) + + if not np.isfinite(A_f).all(): + raise RuntimeError('at least one VAR model coefficient is ' + 'infinite or NaN; check the data you are using') + + try: + np.linalg.cholesky(V) + except np.linalg.LinAlgError as np_error: + raise RuntimeError( + 'the covariance matrix of the residuals is not ' + 'positive-definite; check the singular values of your data ' + 'and specify an appropriate rank argument <= the rank of the ' + 'seeds and targets') from np_error + + return A_f, V + + def _whittle_lwr_recursion(self, G): + """Solve Yule-Walker eqs. for full VAR params. with LWR recursion. + + See: Whittle P., 1963. Biometrika, DOI: 10.1093/biomet/50.1-2.129 + """ + # Initialise recursion + n = G.shape[2] # number of signals + q = G.shape[1] - 1 # number of lags + t = G.shape[0] # number of times + qn = n * q + + cov = G[:, 0, :, :] # covariance + G_f = np.reshape( + G[:, 1:, :, :].transpose(0, 3, 1, 2), (t, qn, n), + order="F") # forward autocov + G_b = np.reshape( + np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), + order="F").transpose(0, 2, 1) # backward autocov + + A_f = np.zeros((t, n, qn)) # forward coefficients + A_b = np.zeros((t, n, qn)) # backward coefficients + + k = 1 # model order + r = q - k + k_f = np.arange(k * n) # forward indices + k_b = np.arange(r * n, qn) # backward indices + + try: + A_f[:, :, k_f] = np.linalg.solve( + cov, G_b[:, k_b, :].transpose(0, 2, 1)).transpose(0, 2, 1) + A_b[:, :, k_b] = np.linalg.solve( + cov, G_f[:, k_f, :].transpose(0, 2, 1)).transpose(0, 2, 1) + + # Perform recursion + for k in np.arange(2, q + 1): + var_A = (G_b[:, (r - 1) * n: r * n, :] - + np.matmul(A_f[:, :, k_f], G_b[:, k_b, :])) + var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :]) + AA_f = np.linalg.solve( + var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) + + var_A = (G_f[:, (k - 1) * n: k * n, :] - + np.matmul(A_b[:, :, k_b], G_f[:, k_f, :])) + var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :]) + AA_b = np.linalg.solve( + var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) + + A_f_previous = A_f[:, :, k_f] + A_b_previous = A_b[:, :, k_b] + + r = q - k + k_f = np.arange(k * n) + k_b = np.arange(r * n, qn) + + A_f[:, :, k_f] = np.dstack( + (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f)) + A_b[:, :, k_b] = np.dstack( + (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous))) + except np.linalg.LinAlgError as np_error: + raise RuntimeError( + 'the autocovariance matrix is singular; check if your data is ' + 'rank deficient and specify an appropriate rank argument <= ' + 'the rank of the seeds and targets') from np_error + + V = cov - np.matmul(A_f, G_f) + A_f = np.reshape(A_f, (t, n, n, q), order="F") + + return A_f, V + + def _full_var_to_iss(self, A_f): + """Compute innovations-form parameters for a state-space model. + + Parameters computed from a full VAR model using Aoki's method. For a + non-moving-average full VAR model, the state-space parameter C + (observation matrix) is identical to AF of the VAR model. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + t = A_f.shape[0] + m = A_f.shape[1] # number of signals + p = A_f.shape[2] // m # number of autoregressive lags + + I_p = np.dstack(t * [np.eye(m * p)]).transpose(2, 0, 1) + A = np.hstack((A_f, I_p[:, : (m * p - m), :])) # state transition + # matrix + K = np.hstack(( + np.dstack(t * [np.eye(m)]).transpose(2, 0, 1), + np.zeros((t, (m * (p - 1)), m)))) # Kalman gain matrix + + return A, K + + def _iss_to_ugc(self, A, C, K, V, seeds, targets): + """Compute unconditional GC from innovations-form state-space params. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + times = np.arange(A.shape[0]) + freqs = np.arange(self.n_freqs) + z = np.exp(-1j * np.pi * np.linspace(0, 1, self.n_freqs)) # points + # on a unit circle in the complex plane, one for each frequency + + H = self._iss_to_tf(A, C, K, z) # spectral transfer function + V_22_1 = np.linalg.cholesky(self._partial_covar(V, seeds, targets)) + HV = np.matmul(H, np.linalg.cholesky(V)) + S = np.matmul(HV, HV.conj().transpose(0, 1, 3, 2)) # Eq. 6 + S_11 = S[np.ix_(freqs, times, targets, targets)] + HV_12 = np.matmul(H[np.ix_(freqs, times, targets, seeds)], V_22_1) + HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2)) + + # Eq. 11 + return np.real( + np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH))) + + def _iss_to_tf(self, A, C, K, z): + """Compute transfer function for innovations-form state-space params. + + In the frequency domain, the back-shift operator, z, is a vector of + points on a unit circle in the complex plane. z = e^-iw, where -pi < w + <= pi. + + A note on efficiency: solving over the 4D time-freq. tensor is slower + than looping over times and freqs when n_times and n_freqs high, and + when n_times and n_freqs low, looping over times and freqs very fast + anyway (plus tensor solving doesn't allow for parallelisation). + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + t = A.shape[0] + h = self.n_freqs + n = C.shape[1] + m = A.shape[1] + I_n = np.eye(n) + I_m = np.eye(m) + H = np.zeros((h, t, n, n), dtype=np.complex128) + + parallel, parallel_compute_H, _ = parallel_func( + _gc_compute_H, self.n_jobs, verbose=False + ) + H = np.zeros((h, t, n, n), dtype=np.complex128) + for block_i in ProgressBar( + range(self.n_steps), mesg="frequency blocks" + ): + freqs = self._get_block_indices(block_i, self.n_freqs) + H[freqs] = parallel( + parallel_compute_H(A, C, K, z[k], I_n, I_m) for k in freqs) + + return H + + def _partial_covar(self, V, seeds, targets): + """Compute partial covariance of a matrix. + + Given a covariance matrix V, the partial covariance matrix of V between + indices i and j, given k (V_ij|k), is equivalent to V_ij - V_ik * + V_kk^-1 * V_kj. In this case, i and j are seeds, and k are targets. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + times = np.arange(V.shape[0]) + W = np.linalg.solve( + np.linalg.cholesky(V[np.ix_(times, targets, targets)]), + V[np.ix_(times, targets, seeds)], + ) + W = np.matmul(W.transpose(0, 2, 1), W) + + return V[np.ix_(times, seeds, seeds)] - W + + def reshape_results(self): + """Remove time dimension from con. scores, if necessary.""" + if self.n_times == 0: + self.con_scores = self.con_scores[:, :, 0] + + +def _gc_compute_H(A, C, K, z_k, I_n, I_m): + """Compute transfer function for innovations-form state-space params. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101, Eq. 4. + """ + from scipy import linalg # XXX: is this necessary??? + H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) + for t in range(A.shape[0]): + H[t] = I_n + np.matmul( + C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t])) + + return H + + +class _GCEst(_GCEstBase): + """[seeds -> targets] state-space GC estimator.""" + + name = "GC" + + +class _GCTREst(_GCEstBase): + """time-reversed[seeds -> targets] state-space GC estimator.""" + + name = "GC time-reversed" + ############################################################################### +_multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr'] _gc_methods = ['gc', 'gc_tr'] @@ -254,9 +1292,9 @@ def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, block_size, psd, accumulate_psd, con_method_types, con_methods, n_signals, n_signals_use, n_times, gc_n_lags, - multivariate_con, accumulate_inplace=True): + accumulate_inplace=True): """Estimate connectivity for one epoch (see spectral_connectivity).""" - if multivariate_con: + if any(this_method in _multivariate_methods for this_method in method): n_con_signals = n_signals_use ** 2 else: n_con_signals = n_cons @@ -273,7 +1311,8 @@ def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, con_methods = [] for mtype in con_method_types: method_params = list(inspect.signature(mtype).parameters) - if multivariate_con: + if "n_signals" in method_params: + # if it's a multivariate connectivity method if "n_lags" in method_params: # if it's a Granger causality method con_methods.append( @@ -462,12 +1501,22 @@ def _get_and_verify_data_sizes(data, sfreq, n_signals=None, n_times=None, return n_signals, n_times, times, warn_times -def _check_estimators(method, con_method_map): +# map names to estimator types +_CON_METHOD_MAP = {'coh': _CohEst, 'cohy': _CohyEst, 'imcoh': _ImCohEst, + 'plv': _PLVEst, 'ciplv': _ciPLVEst, 'ppc': _PPCEst, + 'pli': _PLIEst, 'pli2_unbiased': _PLIUnbiasedEst, + 'dpli': _DPLIEst, 'wpli': _WPLIEst, + 'wpli2_debiased': _WPLIDebiasedEst, 'mic': _MICEst, + 'mim': _MIMEst, 'gc': _GCEst, 'gc_tr': _GCTREst} + + +def _check_estimators(method): """Check construction of connectivity estimators.""" + n_methods = len(method) con_method_types = list() for this_method in method: - if this_method in con_method_map: - con_method_types.append(con_method_map[this_method]) + if this_method in _CON_METHOD_MAP: + con_method_types.append(_CON_METHOD_MAP[this_method]) elif isinstance(this_method, str): raise ValueError('%s is not a valid connectivity method' % this_method) @@ -483,18 +1532,290 @@ def _check_estimators(method, con_method_map): accumulate_psd = any( this_method.accumulate_psd for this_method in con_method_types) - return con_method_types, accumulate_psd + return con_method_types, n_methods, accumulate_psd + + +@ verbose +@ fill_doc +def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, + sfreq=None, + mode='multitaper', fmin=None, fmax=np.inf, + fskip=0, faverage=False, tmin=None, tmax=None, + mt_bandwidth=None, mt_adaptive=False, + mt_low_bias=True, cwt_freqs=None, + cwt_n_cycles=7, gc_n_lags=40, rank=None, + block_size=1000, n_jobs=1, verbose=None): + r"""Compute frequency- and time-frequency-domain connectivity measures. + + The connectivity method(s) are specified using the "method" parameter. + All methods are based on estimates of the cross- and power spectral + densities (CSD/PSD) Sxy and Sxx, Syy. + + Parameters + ---------- + data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs + The data from which to compute connectivity. Note that it is also + possible to combine multiple signals by providing a list of tuples, + e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], + corresponds to 3 epochs, and arr_* could be an array with the same + number of time points as stc_*. The array-like object can also + be a list/generator of array, shape =(n_signals, n_times), + or a list/generator of SourceEstimate or VolSourceEstimate objects. + %(names)s + method : str | list of str + Connectivity measure(s) to compute. These can be ``['coh', 'cohy', + 'imcoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', + 'wpli2_debiased', 'gc', 'gc_tr']``. Multivariate methods (``['mic', + 'mim', 'gc', 'gc_tr]``) cannot be called with the other methods. + indices : tuple of array | None + Two arrays with indices of connections for which to compute + connectivity. If a bivariate method is called, each array for the seeds + and targets should contain the channel indices for each bivariate + connection. If a multivariate method is called, each array for the + seeds and targets should consist of nested arrays containing + the channel indices for each multivariate connection. If ``None``, + connections between all channels are computed, unless a Granger + causality method is called, in which case an error is raised. + sfreq : float + The sampling frequency. Required if data is not + :class:`Epochs `. + mode : str + Spectrum estimation mode can be either: 'multitaper', 'fourier', or + 'cwt_morlet'. + fmin : float | tuple of float + The lower frequency of interest. Multiple bands are defined using + a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. + fmax : float | tuple of float + The upper frequency of interest. Multiple bands are dedined using + a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. + fskip : int + Omit every "(fskip + 1)-th" frequency bin to decimate in frequency + domain. + faverage : bool + Average connectivity scores for each frequency band. If True, + the output freqs will be a list with arrays of the frequencies + that were averaged. + tmin : float | None + Time to start connectivity estimation. Note: when "data" is an array, + the first sample is assumed to be at time 0. For other types + (Epochs, etc.), the time information contained in the object is used + to compute the time indices. + tmax : float | None + Time to end connectivity estimation. Note: when "data" is an array, + the first sample is assumed to be at time 0. For other types + (Epochs, etc.), the time information contained in the object is used + to compute the time indices. + mt_bandwidth : float | None + The bandwidth of the multitaper windowing function in Hz. + Only used in 'multitaper' mode. + mt_adaptive : bool + Use adaptive weights to combine the tapered spectra into PSD. + Only used in 'multitaper' mode. + mt_low_bias : bool + Only use tapers with more than 90 percent spectral concentration + within bandwidth. Only used in 'multitaper' mode. + cwt_freqs : array + Array of frequencies of interest. Only used in 'cwt_morlet' mode. + cwt_n_cycles : float | array of float + Number of cycles. Fixed number or one per frequency. Only used in + 'cwt_morlet' mode. + gc_n_lags : int + Number of lags to use for the vector autoregressive model when + computing Granger causality. Higher values increase computational cost, + but reduce the degree of spectral smoothing in the results. Only used + if ``method`` contains any of ``['gc', 'gc_tr']``. + rank : tuple of array | None + Two arrays with the rank to project the seed and target data to, + respectively, using singular value decomposition. If None, the rank of + the data is computed and projected to. Only used if ``method`` contains + any of ``['mic', 'mim', 'gc', 'gc_tr']``. + block_size : int + How many connections to compute at once (higher numbers are faster + but require more memory). + n_jobs : int + How many samples to process in parallel. + %(verbose)s + + Returns + ------- + con : array | list of array + Computed connectivity measure(s). Either an instance of + ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. + The shape of the connectivity result will be: + + - ``(n_cons, n_freqs)`` for multitaper or fourier modes + - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode + - ``n_cons = n_signals ** 2`` for bivariate methods with + ``indices=None`` + - ``n_cons = 1`` for multivariate methods with ``indices=None`` + - ``n_cons = len(indices[0])`` for bivariate and multivariate methods + when indices is supplied. + + See Also + -------- + mne_connectivity.spectral_connectivity_time + mne_connectivity.SpectralConnectivity + mne_connectivity.SpectroTemporalConnectivity + + Notes + ----- + Please note that the interpretation of the measures in this function + depends on the data and underlying assumptions and does not necessarily + reflect a causal relationship between brain regions. + + These measures are not to be interpreted over time. Each Epoch passed into + the dataset is interpreted as an independent sample of the same + connectivity structure. Within each Epoch, it is assumed that the spectral + measure is stationary. The spectral measures implemented in this function + are computed across Epochs. **Thus, spectral measures computed with only + one Epoch will result in errorful values and spectral measures computed + with few Epochs will be unreliable.** Please see + ``spectral_connectivity_time`` for time-resolved connectivity estimation. + + The spectral densities can be estimated using a multitaper method with + digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier + transform with Hanning windows, or a continuous wavelet transform using + Morlet wavelets. The spectral estimation mode is specified using the + "mode" parameter. + + By default, the connectivity between all signals is computed (only + connections corresponding to the lower-triangular part of the connectivity + matrix). If one is only interested in the connectivity between some + signals, the "indices" parameter can be used. For example, to compute the + connectivity between the signal with index 0 and signals "2, 3, 4" (a total + of 3 connections) one can use the following:: + + indices = (np.array([0, 0, 0]), # row indices + np.array([2, 3, 4])) # col indices + + con = spectral_connectivity_epochs(data, method='coh', + indices=indices, ...) + + In this case con.get_data().shape = (3, n_freqs). The connectivity scores + are in the same order as defined indices. + + For multivariate methods, this is handled differently. If "indices" is + None, connectivity between all signals will be computed and a single + connectivity spectrum will be returned (this is not possible if a Granger + causality method is called). If "indices" is specified, seed and target + indices for each connection should be specified as nested array-likes. For + example, to compute the connectivity between signals (0, 1) -> (2, 3) and + (0, 1) -> (4, 5), indices should be specified as:: + + indices = (np.array([[0, 1], [0, 1]]), # seeds + np.array([[2, 3], [4, 5]])) # targets + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. + + **Supported Connectivity Measures** + + The connectivity method(s) is specified using the "method" parameter. The + following methods are supported (note: ``E[]`` denotes average over + epochs). Multiple measures can be computed at once by using a list/tuple, + e.g., ``['coh', 'pli']`` to compute coherence and PLI. + + 'coh' : Coherence given by:: + + | E[Sxy] | + C = --------------------- + sqrt(E[Sxx] * E[Syy]) + + 'cohy' : Coherency given by:: + + E[Sxy] + C = --------------------- + sqrt(E[Sxx] * E[Syy]) + + 'imcoh' : Imaginary coherence :footcite:`NolteEtAl2004` given by:: + + Im(E[Sxy]) + C = ---------------------- + sqrt(E[Sxx] * E[Syy]) + + 'mic' : Maximised Imaginary part of Coherency (MIC) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} + {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} + \parallel}}` + + where: :math:`\boldsymbol{E}` is the imaginary part of the + transformed cross-spectral density between seeds and targets; and + :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are + eigenvectors for the seeds and targets, such that + :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises + connectivity between the seeds and targets. + + 'mim' : Multivariate Interaction Measure (MIM) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIM=tr(\boldsymbol{EE}^T)` + + 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given + by:: + + PLV = |E[Sxy/|Sxy|]| + + 'ciplv' : corrected imaginary PLV (ciPLV) + :footcite:`BrunaEtAl2018` given by:: + + |E[Im(Sxy/|Sxy|)]| + ciPLV = ------------------------------------ + sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) + + 'ppc' : Pairwise Phase Consistency (PPC), an unbiased estimator + of squared PLV :footcite:`VinckEtAl2010`. + + 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: + + PLI = |E[sign(Im(Sxy))]| + + 'pli2_unbiased' : Unbiased estimator of squared PLI + :footcite:`VinckEtAl2011`. + + 'dpli' : Directed Phase Lag Index (DPLI) :footcite:`StamEtAl2012` + given by (where H is the Heaviside function):: + DPLI = E[H(Im(Sxy))] -def _check_spectral_connectivity_epochs_settings(method, fmin, fmax, n_jobs, - verbose, con_method_map): - """Check settings inputs for spectral_connectivity_epochs... functions.""" + 'wpli' : Weighted Phase Lag Index (WPLI) :footcite:`VinckEtAl2011` + given by:: + + |E[Im(Sxy)]| + WPLI = ------------------ + E[|Im(Sxy)|] + + 'wpli2_debiased' : Debiased estimator of squared WPLI + :footcite:`VinckEtAl2011`. + + 'gc' : State-space Granger Causality (GC) :footcite:`BarnettSeth2015` + given by: + + :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert + \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss + \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, + + where: :math:`s` and :math:`t` represent the seeds and targets, + respectively; :math:`\boldsymbol{H}` is the spectral transfer + function; :math:`\boldsymbol{\Sigma}` is the residuals matrix of + the autoregressive model; and :math:`\boldsymbol{S}` is + :math:`\boldsymbol{\Sigma}` transformed by :math:`\boldsymbol{H}`. + + 'gc_tr' : State-space GC on time-reversed signals + :footcite:`BarnettSeth2015,WinklerEtAl2016` given by the same equation + as for 'gc', but where the autocovariance sequence from which the + autoregressive model is produced is transposed to mimic the reversal of + the original signal in time. + + References + ---------- + .. footbibliography:: + """ if n_jobs != 1: parallel, my_epoch_spectral_connectivity, _ = parallel_func( _epoch_spectral_connectivity, n_jobs, verbose=verbose) - else: - parallel = None - my_epoch_spectral_connectivity = None # format fmin and fmax and check inputs if fmin is None: @@ -506,22 +1827,34 @@ def _check_spectral_connectivity_epochs_settings(method, fmin, fmax, n_jobs, raise ValueError('fmin and fmax must have the same length') if np.any(fmin > fmax): raise ValueError('fmax must be larger than fmin') + n_bands = len(fmin) # assign names to connectivity methods if not isinstance(method, (list, tuple)): method = [method] # make it a list so we can iterate over it - # handle connectivity estimators - con_method_types, accumulate_psd = _check_estimators(method, - con_method_map) - - return (fmin, fmax, n_bands, method, con_method_types, accumulate_psd, - parallel, my_epoch_spectral_connectivity) + if n_bands != 1 and any( + this_method in _gc_methods for this_method in method + ): + raise ValueError('computing Granger causality on multiple frequency ' + 'bands is not yet supported') + + if any(this_method in _multivariate_methods for this_method in method): + if not all(this_method in _multivariate_methods for + this_method in method): + raise ValueError( + 'bivariate and multivariate connectivity methods cannot be ' + 'used in the same function call') + multivariate_con = True + else: + multivariate_con = False + # handle connectivity estimators + (con_method_types, n_methods, accumulate_psd) = _check_estimators(method) -def _check_spectral_connectivity_epochs_data(data, sfreq, names): - """Check data inputs for spectral_connectivity_epochs... functions.""" + events = None + event_id = None if isinstance(data, BaseEpochs): names = data.ch_names times_in = data.times # input times for Epochs input type @@ -543,23 +1876,208 @@ def _check_spectral_connectivity_epochs_data(data, sfreq, names): data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata else: - events = None - event_id = None times_in = None metadata = None if sfreq is None: raise ValueError('Sampling frequency (sfreq) is required with ' 'array input.') - return (names, times_in, sfreq, events, event_id, metadata) + # loop over data; it could be a generator that returns + # (n_signals x n_times) arrays or SourceEstimates + epoch_idx = 0 + logger.info('Connectivity computation...') + warn_times = True + for epoch_block in _get_n_epochs(data, n_jobs): + if epoch_idx == 0: + # initialize everything times and frequencies + (n_cons, times, n_times, times_in, n_times_in, tmin_idx, + tmax_idx, n_freqs, freq_mask, freqs, freqs_bands, freq_idx_bands, + n_signals, indices_use, warn_times) = _prepare_connectivity( + epoch_block=epoch_block, times_in=times_in, + tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, + indices=indices, method=method, mode=mode, fskip=fskip, + n_bands=n_bands, cwt_freqs=cwt_freqs, faverage=faverage) + + # check rank input and compute data ranks if necessary + if multivariate_con: + rank = _check_rank_input(rank, data, indices_use) + else: + rank = None + gc_n_lags = None + + # make sure padded indices are stored in the connectivity object + if multivariate_con and indices is not None: + indices = tuple(np.array(indices_use)) # create a copy + + # get the window function, wavelets, etc for different modes + (spectral_params, mt_adaptive, n_times_spectrum, + n_tapers) = _assemble_spectral_params( + mode=mode, n_times=n_times, mt_adaptive=mt_adaptive, + mt_bandwidth=mt_bandwidth, sfreq=sfreq, + mt_low_bias=mt_low_bias, cwt_n_cycles=cwt_n_cycles, + cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) + + # unique signals for which we actually need to compute PSD etc. + if multivariate_con: + sig_idx = np.unique(np.concatenate(np.concatenate( + indices_use))) + sig_idx = sig_idx[sig_idx != -1] + remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} + remapping[-1] = -1 + remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + remapped_inds[0][con_i] = np.array([ + remapping[idx] for idx in seed]) + remapped_inds[1][con_i] = np.array([ + remapping[idx] for idx in target]) + con_i += 1 + remapped_sig = [remapping[idx] for idx in sig_idx] + else: + sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) + n_signals_use = len(sig_idx) + + # map indices to unique indices + if multivariate_con: + indices_use = remapped_inds # use remapped seeds & targets + idx_map = [np.sort(np.repeat(remapped_sig, len(sig_idx))), + np.tile(remapped_sig, len(sig_idx))] + else: + idx_map = [ + np.searchsorted(sig_idx, ind) for ind in indices_use] + # allocate space to accumulate PSD + if accumulate_psd: + if n_times_spectrum == 0: + psd_shape = (n_signals_use, n_freqs) + else: + psd_shape = (n_signals_use, n_freqs, n_times_spectrum) + psd = np.zeros(psd_shape) + else: + psd = None + + # create instances of the connectivity estimators + con_methods = [] + for mtype_i, mtype in enumerate(con_method_types): + method_params = dict(n_cons=n_cons, n_freqs=n_freqs, + n_times=n_times_spectrum) + if method[mtype_i] in _multivariate_methods: + method_params.update(dict(n_signals=n_signals_use)) + if method[mtype_i] in _gc_methods: + method_params.update(dict(n_lags=gc_n_lags)) + con_methods.append(mtype(**method_params)) + + sep = ', ' + metrics_str = sep.join([meth.name for meth in con_methods]) + logger.info(' the following metrics will be computed: %s' + % metrics_str) + + # check dimensions and time scale + for this_epoch in epoch_block: + _, _, _, warn_times = _get_and_verify_data_sizes( + this_epoch, sfreq, n_signals, n_times_in, times_in, + warn_times=warn_times) + + call_params = dict( + sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, + method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, + n_cons=n_cons, block_size=block_size, + psd=psd, accumulate_psd=accumulate_psd, + mt_adaptive=mt_adaptive, + con_method_types=con_method_types, + con_methods=con_methods if n_jobs == 1 else None, + n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, + gc_n_lags=gc_n_lags, + accumulate_inplace=True if n_jobs == 1 else False) + call_params.update(**spectral_params) + + if n_jobs == 1: + # no parallel processing + for this_epoch in epoch_block: + logger.info(' computing cross-spectral density for epoch %d' + % (epoch_idx + 1)) + # con methods and psd are updated inplace + _epoch_spectral_connectivity(data=this_epoch, **call_params) + epoch_idx += 1 + else: + # process epochs in parallel + logger.info( + ' computing cross-spectral density for epochs %d..%d' + % (epoch_idx + 1, epoch_idx + len(epoch_block))) + + out = parallel(my_epoch_spectral_connectivity( + data=this_epoch, **call_params) + for this_epoch in epoch_block) + # do the accumulation + for this_out in out: + for _method, parallel_method in zip(con_methods, this_out[0]): + _method.combine(parallel_method) + if accumulate_psd: + psd += this_out[1] + + epoch_idx += len(epoch_block) + + # normalize + n_epochs = epoch_idx + if accumulate_psd: + psd /= n_epochs + + # compute final connectivity scores + con = list() + patterns = list() + for method_i, conn_method in enumerate(con_methods): + + # future estimators will need to be handled here + if conn_method.accumulate_psd: + # compute scores block-wise to save memory + for i in range(0, n_cons, block_size): + con_idx = slice(i, i + block_size) + psd_xx = psd[idx_map[0][con_idx]] + psd_yy = psd[idx_map[1][con_idx]] + conn_method.compute_con(con_idx, n_epochs, psd_xx, psd_yy) + else: + # compute all scores at once + if method[method_i] in _multivariate_methods: + conn_method.compute_con(indices_use, rank, n_epochs) + else: + conn_method.compute_con(slice(0, n_cons), n_epochs) + + # get the connectivity scores + this_con = conn_method.con_scores + this_patterns = conn_method.patterns + + if this_con.shape[0] != n_cons: + raise RuntimeError( + 'first dimension of connectivity scores does not match the ' + 'number of connections; please contact the mne-connectivity ' + 'developers') + if faverage: + if this_con.shape[1] != n_freqs: + raise RuntimeError( + 'second dimension of connectivity scores does not match ' + 'the number of frequencies; please contact the ' + 'mne-connectivity developers') + con_shape = (n_cons, n_bands) + this_con.shape[2:] + this_con_bands = np.empty(con_shape, dtype=this_con.dtype) + for band_idx in range(n_bands): + this_con_bands[:, band_idx] = np.mean( + this_con[:, freq_idx_bands[band_idx]], axis=1) + this_con = this_con_bands + + if this_patterns is not None: + patterns_shape = list(this_patterns.shape) + patterns_shape[3] = n_bands + this_patterns_bands = np.empty(patterns_shape, + dtype=this_patterns.dtype) + for band_idx in range(n_bands): + this_patterns_bands[:, :, :, band_idx] = np.mean( + this_patterns[:, :, :, freq_idx_bands[band_idx]], + axis=3) + this_patterns = this_patterns_bands + + con.append(this_con) + patterns.append(this_patterns) -def _store_results( - con, patterns, method, freqs, faverage, freqs_bands, names, mode, indices, - n_epochs, times, n_tapers, metadata, events, event_id, rank, gc_n_lags, - n_signals -): - """Store results in connectivity containers.""" freqs_used = freqs if faverage: # for each band we return the frequencies that were averaged @@ -572,6 +2090,23 @@ def _store_results( freqs_used = freqs_bands freqs_used = [[np.min(band), np.max(band)] for band in freqs_used] + if indices is None and not multivariate_con: + # return all-to-all connectivity matrices + # raveled into a 1D array + logger.info(' assembling connectivity matrix') + con_flat = con + con = list() + for this_con_flat in con_flat: + this_con = np.zeros((n_signals, n_signals) + + this_con_flat.shape[1:], + dtype=this_con_flat.dtype) + this_con[indices_use] = this_con_flat + + # ravel 2D connectivity into a 1D array + # while keeping other dimensions + this_con = this_con.reshape((n_signals ** 2,) + + this_con_flat.shape[1:]) + con.append(this_con) # number of nodes in the original data n_nodes = n_signals @@ -596,7 +2131,7 @@ def _store_results( logger.info('[Connectivity computation done]') - if len(method) == 1: + if n_methods == 1: # for a single method return connectivity directly conn_list = conn_list[0] diff --git a/mne_connectivity/spectral/epochs_bivariate.py b/mne_connectivity/spectral/epochs_bivariate.py deleted file mode 100644 index 044de3b4..00000000 --- a/mne_connectivity/spectral/epochs_bivariate.py +++ /dev/null @@ -1,729 +0,0 @@ -# Authors: Martin Luessi -# Denis A. Engemann -# Adam Li -# Thomas S. Binns -# -# License: BSD (3-clause) - -import numpy as np -from mne.utils import logger, verbose - -from .epochs import ( - _AbstractConEstBase, _check_spectral_connectivity_epochs_settings, - _check_spectral_connectivity_epochs_data, _get_n_epochs, - _prepare_connectivity, _assemble_spectral_params, - _compute_spectral_methods_epochs, _store_results) -from ..utils import fill_doc, check_indices - - -def _check_indices(indices, n_signals): - if indices is None: - logger.info('only using indices for lower-triangular matrix') - # only compute r for lower-triangular region - indices_use = np.tril_indices(n_signals, -1) - else: - indices_use = check_indices(indices) - - # number of connectivities to compute - n_cons = len(indices_use[0]) - logger.info(' computing connectivity for %d connections' % n_cons) - - return n_cons, indices_use - - -######################################################################## -# Bivariate connectivity estimators - - -class _EpochMeanConEstBase(_AbstractConEstBase): - """Base class for methods that estimate connectivity as mean epoch-wise.""" - - patterns = None - - def __init__(self, n_cons, n_freqs, n_times): - self.n_cons = n_cons - self.n_freqs = n_freqs - self.n_times = n_times - - if n_times == 0: - self.csd_shape = (n_cons, n_freqs) - else: - self.csd_shape = (n_cons, n_freqs, n_times) - - self.con_scores = None - - def start_epoch(self): # noqa: D401 - """Called at the start of each epoch.""" - pass # for this type of con. method we don't do anything - - def combine(self, other): - """Include con. accumated for some epochs in this estimate.""" - self._acc += other._acc - - -class _CohEstBase(_EpochMeanConEstBase): - """Base Estimator for Coherence, Coherency, Imag. Coherence.""" - - accumulate_psd = True - - def __init__(self, n_cons, n_freqs, n_times): - super(_CohEstBase, self).__init__(n_cons, n_freqs, n_times) - - # allocate space for accumulation of CSD - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate CSD for some connections.""" - self._acc[con_idx] += csd_xy - - -class _CohEst(_CohEstBase): - """Coherence Estimator.""" - - name = 'Coherence' - - def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - csd_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy) - - -class _CohyEst(_CohEstBase): - """Coherency Estimator.""" - - name = 'Coherency' - - def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape, - dtype=np.complex128) - csd_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = csd_mean / np.sqrt(psd_xx * psd_yy) - - -class _ImCohEst(_CohEstBase): - """Imaginary Coherence Estimator.""" - - name = 'Imaginary Coherence' - - def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - csd_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = np.imag(csd_mean) / np.sqrt(psd_xx * psd_yy) - - -class _PLVEst(_EpochMeanConEstBase): - """PLV Estimator.""" - - name = 'PLV' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_PLVEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += csd_xy / np.abs(csd_xy) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - plv = np.abs(self._acc / n_epochs) - self.con_scores[con_idx] = plv - - -class _ciPLVEst(_EpochMeanConEstBase): - """corrected imaginary PLV Estimator.""" - - name = 'ciPLV' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_ciPLVEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += csd_xy / np.abs(csd_xy) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - imag_plv = np.abs(np.imag(self._acc)) / n_epochs - real_plv = np.real(self._acc) / n_epochs - real_plv = np.clip(real_plv, -1, 1) # bounded from -1 to 1 - mask = (np.abs(real_plv) == 1) # avoid division by 0 - real_plv[mask] = 0 - corrected_imag_plv = imag_plv / np.sqrt(1 - real_plv ** 2) - self.con_scores[con_idx] = corrected_imag_plv - - -class _PLIEst(_EpochMeanConEstBase): - """PLI Estimator.""" - - name = 'PLI' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_PLIEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += np.sign(np.imag(csd_xy)) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - pli_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = np.abs(pli_mean) - - -class _PLIUnbiasedEst(_PLIEst): - """Unbiased PLI Square Estimator.""" - - name = 'Unbiased PLI Square' - accumulate_psd = False - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - pli_mean = self._acc[con_idx] / n_epochs - - # See Vinck paper Eq. (30) - con = (n_epochs * pli_mean ** 2 - 1) / (n_epochs - 1) - - self.con_scores[con_idx] = con - - -class _DPLIEst(_EpochMeanConEstBase): - """DPLI Estimator.""" - - name = 'DPLI' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_DPLIEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += np.heaviside(np.imag(csd_xy), 0.5) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - con = self._acc[con_idx] / n_epochs - - self.con_scores[con_idx] = con - - -class _WPLIEst(_EpochMeanConEstBase): - """WPLI Estimator.""" - - name = 'WPLI' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_WPLIEst, self).__init__(n_cons, n_freqs, n_times) - - # store both imag(csd) and abs(imag(csd)) - acc_shape = (2,) + self.csd_shape - self._acc = np.zeros(acc_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - im_csd = np.imag(csd_xy) - self._acc[0, con_idx] += im_csd - self._acc[1, con_idx] += np.abs(im_csd) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - num = np.abs(self._acc[0, con_idx]) - denom = self._acc[1, con_idx] - - # handle zeros in denominator - z_denom = np.where(denom == 0.) - denom[z_denom] = 1. - - con = num / denom - - # where we had zeros in denominator, we set con to zero - con[z_denom] = 0. - - self.con_scores[con_idx] = con - - -class _WPLIDebiasedEst(_EpochMeanConEstBase): - """Debiased WPLI Square Estimator.""" - - name = 'Debiased WPLI Square' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_WPLIDebiasedEst, self).__init__(n_cons, n_freqs, n_times) - # store imag(csd), abs(imag(csd)), imag(csd)^2 - acc_shape = (3,) + self.csd_shape - self._acc = np.zeros(acc_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - im_csd = np.imag(csd_xy) - self._acc[0, con_idx] += im_csd - self._acc[1, con_idx] += np.abs(im_csd) - self._acc[2, con_idx] += im_csd ** 2 - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - # note: we use the trick from fieldtrip to compute the - # the estimate over all pairwise epoch combinations - sum_im_csd = self._acc[0, con_idx] - sum_abs_im_csd = self._acc[1, con_idx] - sum_sq_im_csd = self._acc[2, con_idx] - - denom = sum_abs_im_csd ** 2 - sum_sq_im_csd - - # handle zeros in denominator - z_denom = np.where(denom == 0.) - denom[z_denom] = 1. - - con = (sum_im_csd ** 2 - sum_sq_im_csd) / denom - - # where we had zeros in denominator, we set con to zero - con[z_denom] = 0. - - self.con_scores[con_idx] = con - - -class _PPCEst(_EpochMeanConEstBase): - """Pairwise Phase Consistency (PPC) Estimator.""" - - name = 'PPC' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_PPCEst, self).__init__(n_cons, n_freqs, n_times) - - # store csd / abs(csd) - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - denom = np.abs(csd_xy) - z_denom = np.where(denom == 0.) - denom[z_denom] = 1. - this_acc = csd_xy / denom - this_acc[z_denom] = 0. # handle division by zero - - self._acc[con_idx] += this_acc - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - # note: we use the trick from fieldtrip to compute the - # the estimate over all pairwise epoch combinations - con = ((self._acc[con_idx] * np.conj(self._acc[con_idx]) - n_epochs) / - (n_epochs * (n_epochs - 1.))) - - self.con_scores[con_idx] = np.real(con) - - -############################################################################### - - -# map names to estimator types -_CON_METHOD_MAP = {'coh': _CohEst, 'cohy': _CohyEst, 'imcoh': _ImCohEst, - 'plv': _PLVEst, 'ciplv': _ciPLVEst, 'ppc': _PPCEst, - 'pli': _PLIEst, 'pli2_unbiased': _PLIUnbiasedEst, - 'dpli': _DPLIEst, 'wpli': _WPLIEst, - 'wpli2_debiased': _WPLIDebiasedEst} - - -@ verbose -@ fill_doc -def spectral_connectivity_epochs( - data, names=None, method='coh', indices=None, sfreq=None, - mode='multitaper', fmin=None, fmax=np.inf, fskip=0, faverage=False, - tmin=None, tmax=None, mt_bandwidth=None, mt_adaptive=False, - mt_low_bias=True, cwt_freqs=None, cwt_n_cycles=7, block_size=1000, - n_jobs=1, verbose=None -): - """Compute bivariate (time-)frequency-domain connectivity measures. - - The connectivity method(s) are specified using the "method" parameter. - All methods are based on estimates of the cross- and power spectral - densities (CSD/PSD) Sxy and Sxx, Syy. - - Parameters - ---------- - data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs - The data from which to compute connectivity. Note that it is also - possible to combine multiple signals by providing a list of tuples, - e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], - corresponds to 3 epochs, and arr_* could be an array with the same - number of time points as stc_*. The array-like object can also - be a list/generator of array, shape =(n_signals, n_times), - or a list/generator of SourceEstimate or VolSourceEstimate objects. - %(names)s - method : str | list of str - Connectivity measure(s) to compute. These can be ``['coh', 'cohy', - 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', - 'wpli2_debiased']``. - indices : tuple of array | None - Two arrays with indices of connections for which to compute - connectivity. Each array for the seeds and targets should contain the - channel indices for each bivariate connection. If ``None``, connections - between all channels are computed. - sfreq : float - The sampling frequency. Required if data is not - :class:`Epochs `. - mode : str - Spectrum estimation mode can be either: 'multitaper', 'fourier', or - 'cwt_morlet'. - fmin : float | tuple of float - The lower frequency of interest. Multiple bands are defined using - a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. - fmax : float | tuple of float - The upper frequency of interest. Multiple bands are dedined using - a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. - fskip : int - Omit every "(fskip + 1)-th" frequency bin to decimate in frequency - domain. - faverage : bool - Average connectivity scores for each frequency band. If True, - the output freqs will be a list with arrays of the frequencies - that were averaged. - tmin : float | None - Time to start connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - tmax : float | None - Time to end connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - mt_bandwidth : float | None - The bandwidth of the multitaper windowing function in Hz. - Only used in 'multitaper' mode. - mt_adaptive : bool - Use adaptive weights to combine the tapered spectra into PSD. - Only used in 'multitaper' mode. - mt_low_bias : bool - Only use tapers with more than 90 percent spectral concentration - within bandwidth. Only used in 'multitaper' mode. - cwt_freqs : array - Array of frequencies of interest. Only used in 'cwt_morlet' mode. - cwt_n_cycles : float | array of float - Number of cycles. Fixed number or one per frequency. Only used in - 'cwt_morlet' mode. - block_size : int - How many connections to compute at once (higher numbers are faster - but require more memory). - n_jobs : int - How many samples to process in parallel. - %(verbose)s - - Returns - ------- - con : array | list of array - Computed connectivity measure(s). Either an instance of - ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of the connectivity result will be: - - - ``(n_cons, n_freqs)`` for multitaper or fourier modes - - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode - - ``n_cons = n_signals ** 2`` with ``indices=None`` - - ``n_cons = len(indices[0])`` when indices is supplied. - - See Also - -------- - mne_connectivity.spectral_connectivity_epochs_multivariate - mne_connectivity.spectral_connectivity_time - mne_connectivity.SpectralConnectivity - mne_connectivity.SpectroTemporalConnectivity - - Notes - ----- - Please note that the interpretation of the measures in this function - depends on the data and underlying assumptions and does not necessarily - reflect a causal relationship between brain regions. - - These measures are not to be interpreted over time. Each Epoch passed into - the dataset is interpreted as an independent sample of the same - connectivity structure. Within each Epoch, it is assumed that the spectral - measure is stationary. The spectral measures implemented in this function - are computed across Epochs. **Thus, spectral measures computed with only - one Epoch will result in errorful values and spectral measures computed - with few Epochs will be unreliable.** Please see - ``spectral_connectivity_time`` for time-resolved connectivity estimation. - - The spectral densities can be estimated using a multitaper method with - digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier - transform with Hanning windows, or a continuous wavelet transform using - Morlet wavelets. The spectral estimation mode is specified using the - "mode" parameter. - - By default, the connectivity between all signals is computed (only - connections corresponding to the lower-triangular part of the connectivity - matrix). If one is only interested in the connectivity between some - signals, the "indices" parameter can be used. For example, to compute the - connectivity between the signal with index 0 and signals "2, 3, 4" (a total - of 3 connections) one can use the following:: - - indices = (np.array([0, 0, 0]), # row indices - np.array([2, 3, 4])) # col indices - - con = spectral_connectivity_epochs(data, method='coh', - indices=indices, ...) - - In this case con.get_data().shape = (3, n_freqs). The connectivity scores - are in the same order as defined indices. - - **Supported Connectivity Measures** - - The connectivity method(s) is specified using the "method" parameter. The - following methods are supported (note: ``E[]`` denotes average over - epochs). Multiple measures can be computed at once by using a list/tuple, - e.g., ``['coh', 'pli']`` to compute coherence and PLI. - - 'coh' : Coherence given by:: - - | E[Sxy] | - C = --------------------- - sqrt(E[Sxx] * E[Syy]) - - 'cohy' : Coherency given by:: - - E[Sxy] - C = --------------------- - sqrt(E[Sxx] * E[Syy]) - - 'imcoh' : Imaginary coherence :footcite:`NolteEtAl2004` given by:: - - Im(E[Sxy]) - C = ---------------------- - sqrt(E[Sxx] * E[Syy]) - - 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given - by:: - - PLV = |E[Sxy/|Sxy|]| - - 'ciplv' : corrected imaginary PLV (ciPLV) - :footcite:`BrunaEtAl2018` given by:: - - |E[Im(Sxy/|Sxy|)]| - ciPLV = ------------------------------------ - sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) - - 'ppc' : Pairwise Phase Consistency (PPC), an unbiased estimator - of squared PLV :footcite:`VinckEtAl2010`. - - 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: - - PLI = |E[sign(Im(Sxy))]| - - 'pli2_unbiased' : Unbiased estimator of squared PLI - :footcite:`VinckEtAl2011`. - - 'dpli' : Directed Phase Lag Index (DPLI) :footcite:`StamEtAl2012` - given by (where H is the Heaviside function):: - - DPLI = E[H(Im(Sxy))] - - 'wpli' : Weighted Phase Lag Index (WPLI) :footcite:`VinckEtAl2011` - given by:: - - |E[Im(Sxy)]| - WPLI = ------------------ - E[|Im(Sxy)|] - - 'wpli2_debiased' : Debiased estimator of squared WPLI - :footcite:`VinckEtAl2011`. - - References - ---------- - .. footbibliography:: - """ - ( - fmin, fmax, n_bands, method, con_method_types, accumulate_psd, - parallel, my_epoch_spectral_connectivity - ) = _check_spectral_connectivity_epochs_settings( - method, fmin, fmax, n_jobs, verbose, _CON_METHOD_MAP) - - (names, times_in, sfreq, events, event_id, - metadata) = _check_spectral_connectivity_epochs_data(data, sfreq, names) - - # loop over data; it could be a generator that returns - # (n_signals x n_times) arrays or SourceEstimates - epoch_idx = 0 - logger.info('Connectivity computation...') - warn_times = True - for epoch_block in _get_n_epochs(data, n_jobs): - if epoch_idx == 0: - # initialize everything times and frequencies - (times, n_times, times_in, n_times_in, tmin_idx, tmax_idx, n_freqs, - freq_mask, freqs, freqs_bands, freq_idx_bands, n_signals, - warn_times) = _prepare_connectivity( - epoch_block=epoch_block, times_in=times_in, tmin=tmin, - tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, mode=mode, - fskip=fskip, n_bands=n_bands, cwt_freqs=cwt_freqs, - faverage=faverage) - - # check indices input - n_cons, indices_use = _check_indices(indices, n_signals) - - # get the window function, wavelets, etc for different modes - (spectral_params, mt_adaptive, n_times_spectrum, - n_tapers) = _assemble_spectral_params( - mode=mode, n_times=n_times, mt_adaptive=mt_adaptive, - mt_bandwidth=mt_bandwidth, sfreq=sfreq, - mt_low_bias=mt_low_bias, cwt_n_cycles=cwt_n_cycles, - cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) - - # unique signals for which we actually need to compute CSD/PSD - sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) - n_signals_use = len(sig_idx) - - # map indices to unique indices - idx_map = [np.searchsorted(sig_idx, ind) for ind in indices_use] - - # allocate space to accumulate PSD - if accumulate_psd: - if n_times_spectrum == 0: - psd_shape = (n_signals_use, n_freqs) - else: - psd_shape = (n_signals_use, n_freqs, n_times_spectrum) - psd = np.zeros(psd_shape) - else: - psd = None - - # create instances of the connectivity estimators - con_methods = [] - for mtype in con_method_types: - con_methods.append(mtype(n_cons=n_cons, n_freqs=n_freqs, - n_times=n_times_spectrum)) - - sep = ', ' - metrics_str = sep.join([meth.name for meth in con_methods]) - logger.info(' the following metrics will be computed: %s' - % metrics_str) - - call_params = dict( - sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, - method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, - n_cons=n_cons, block_size=block_size, - psd=psd, accumulate_psd=accumulate_psd, - mt_adaptive=mt_adaptive, - con_method_types=con_method_types, - con_methods=con_methods if n_jobs == 1 else None, - n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, - gc_n_lags=None, multivariate_con=False, - accumulate_inplace=True if n_jobs == 1 else False) - call_params.update(**spectral_params) - - epoch_idx = _compute_spectral_methods_epochs( - con_methods, epoch_block, epoch_idx, call_params, parallel, - my_epoch_spectral_connectivity, n_jobs, n_times_in, times_in, - warn_times) - - # normalize - n_epochs = epoch_idx - if accumulate_psd: - psd /= n_epochs - - # compute final connectivity scores - con = list() - for conn_method in con_methods: - - # future estimators will need to be handled here - if conn_method.accumulate_psd: - # compute scores block-wise to save memory - for i in range(0, n_cons, block_size): - con_idx = slice(i, i + block_size) - psd_xx = psd[idx_map[0][con_idx]] - psd_yy = psd[idx_map[1][con_idx]] - conn_method.compute_con(con_idx, n_epochs, psd_xx, psd_yy) - else: - # compute all scores at once - conn_method.compute_con(slice(0, n_cons), n_epochs) - - # get the connectivity scores - this_con = conn_method.con_scores - - if this_con.shape[0] != n_cons: - raise RuntimeError( - 'first dimension of connectivity scores does not match the ' - 'number of connections; please contact the mne-connectivity ' - 'developers') - if faverage: - if this_con.shape[1] != n_freqs: - raise RuntimeError( - 'second dimension of connectivity scores does not match ' - 'the number of frequencies; please contact the ' - 'mne-connectivity developers') - con_shape = (n_cons, n_bands) + this_con.shape[2:] - this_con_bands = np.empty(con_shape, dtype=this_con.dtype) - for band_idx in range(n_bands): - this_con_bands[:, band_idx] = np.mean( - this_con[:, freq_idx_bands[band_idx]], axis=1) - this_con = this_con_bands - - con.append(this_con) - # No patterns for bivariate connectivity - patterns = [None for _ in range(len(con))] - - # return all-to-all connectivity matrices raveled into a 1D array - if indices is None: - logger.info(' assembling connectivity matrix') - con_flat = con - con = list() - for this_con_flat in con_flat: - this_con = np.zeros((n_signals, n_signals) + - this_con_flat.shape[1:], - dtype=this_con_flat.dtype) - this_con[indices_use] = this_con_flat - - # ravel 2D connectivity into a 1D array - # while keeping other dimensions - this_con = this_con.reshape((n_signals ** 2,) + - this_con_flat.shape[1:]) - con.append(this_con) - - conn_list = _store_results( - con=con, patterns=patterns, method=method, freqs=freqs, - faverage=faverage, freqs_bands=freqs_bands, names=names, mode=mode, - indices=indices, n_epochs=n_epochs, times=times, n_tapers=n_tapers, - metadata=metadata, events=events, event_id=event_id, rank=None, - gc_n_lags=None, n_signals=n_signals) - - return conn_list diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py deleted file mode 100644 index 28077adb..00000000 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ /dev/null @@ -1,1129 +0,0 @@ -# Authors: Martin Luessi -# Denis A. Engemann -# Adam Li -# Thomas S. Binns -# Tien D. Nguyen -# Richard M. Köhler -# -# License: BSD (3-clause) - -import numpy as np -import scipy as sp -from mne.epochs import BaseEpochs -from mne.parallel import parallel_func -from mne.utils import ProgressBar, logger, verbose - -from .epochs import ( - _AbstractConEstBase, _check_spectral_connectivity_epochs_settings, - _check_spectral_connectivity_epochs_data, _get_n_epochs, - _prepare_connectivity, _assemble_spectral_params, - _compute_spectral_methods_epochs, _store_results) -from ..utils import fill_doc, check_multivariate_indices - - -def _check_indices(indices, method, n_signals): - if indices is None: - if any(this_method in _gc_methods for this_method in method): - raise ValueError( - 'indices must be specified when computing Granger causality, ' - 'as all-to-all connectivity is not supported') - else: - logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], - np.arange(n_signals, dtype=int)[np.newaxis, :]) - else: - indices_use = check_multivariate_indices(indices) # pad with -1 - if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices[0], indices[1]): - intersection = np.intersect1d(seed, target) - if np.any(intersection != -1): # ignore padded entries - raise ValueError( - 'seed and target indices must not intersect when ' - 'computing Granger causality') - - # number of connectivities to compute - n_cons = len(indices_use[0]) - logger.info(' computing connectivity for %d connections' % n_cons) - - return n_cons, indices_use - - -def _check_rank_input(rank, data, indices): - """Check the rank argument is appropriate and compute rank if missing.""" - sv_tol = 1e-10 # tolerance for non-zero singular val (rel. to largest) - if rank is None: - rank = np.zeros((2, len(indices[0])), dtype=int) - - if isinstance(data, BaseEpochs): - data_arr = data.get_data() - else: - data_arr = data - - # XXX: Unpadding of arrays after already padding them is perhaps not so - # efficient. However, we need to remove the padded values to - # ensure only the correct channels are indexed, and having two - # versions of indices is a bit messy currently. A candidate for - # refactoring to simplify code. - - for group_i in range(2): # seeds and targets - for con_i, con_idcs in enumerate(indices[group_i]): - con_idcs = con_idcs[con_idcs != -1] # -1 is padded value - s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) - rank[group_i][con_i] = np.min( - [np.count_nonzero(epoch >= epoch[0] * sv_tol) - for epoch in s]) - - logger.info('Estimated data ranks:') - con_i = 1 - for seed_rank, target_rank in zip(rank[0], rank[1]): - logger.info(' connection %i - seeds (%i); targets (%i)' - % (con_i, seed_rank, target_rank, )) - con_i += 1 - rank = tuple((np.array(rank[0]), np.array(rank[1]))) - - else: - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], rank[0], rank[1]): - if not (0 < seed_rank <= len(seed_idcs) and - 0 < target_rank <= len(target_idcs)): - raise ValueError( - 'ranks for seeds and targets must be > 0 and <= the ' - 'number of channels in the seeds and targets, ' - 'respectively, for each connection') - - return rank - - -######################################################################## -# Multivariate connectivity estimators - -class _EpochMeanMultivariateConEstBase(_AbstractConEstBase): - """Base class for mean epoch-wise multivar. con. estimation methods.""" - - n_steps = None - patterns = None - - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): - self.n_signals = n_signals - self.n_cons = n_cons - self.n_freqs = n_freqs - self.n_times = n_times - self.n_jobs = n_jobs - - # include time dimension, even when unused for indexing flexibility - if n_times == 0: - self.csd_shape = (n_signals**2, n_freqs) - self.con_scores = np.zeros((n_cons, n_freqs, 1)) - else: - self.csd_shape = (n_signals**2, n_freqs, n_times) - self.con_scores = np.zeros((n_cons, n_freqs, n_times)) - - # allocate space for accumulation of CSD - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - self._compute_n_progress_bar_steps() - - def start_epoch(self): # noqa: D401 - """Called at the start of each epoch.""" - pass # for this type of con. method we don't do anything - - def combine(self, other): - """Include con. accumulated for some epochs in this estimate.""" - self._acc += other._acc - - def accumulate(self, con_idx, csd_xy): - """Accumulate CSD for some connections.""" - self._acc[con_idx] += csd_xy - - def _compute_n_progress_bar_steps(self): - """Calculate the number of steps to include in the progress bar.""" - self.n_steps = int(np.ceil(self.n_freqs / self.n_jobs)) - - def _log_connection_number(self, con_i): - """Log the number of the connection being computed.""" - logger.info('Computing %s for connection %i of %i' - % (self.name, con_i + 1, self.n_cons, )) - - def _get_block_indices(self, block_i, limit): - """Get indices for a computation block capped by a limit.""" - indices = np.arange(block_i * self.n_jobs, (block_i + 1) * self.n_jobs) - - return indices[np.nonzero(indices < limit)] - - def reshape_csd(self): - """Reshape CSD into a matrix of times x freqs x signals x signals.""" - if self.n_times == 0: - return (np.reshape(self._acc, ( - self.n_signals, self.n_signals, self.n_freqs, 1) - ).transpose(3, 2, 0, 1)) - - return (np.reshape(self._acc, ( - self.n_signals, self.n_signals, self.n_freqs, self.n_times) - ).transpose(3, 2, 0, 1)) - - -class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): - """Base estimator for multivariate imag. part of coherency methods. - - See Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 - for equation references. - """ - - name = None - accumulate_psd = False - - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): - super(_MultivariateCohEstBase, self).__init__( - n_signals, n_cons, n_freqs, n_times, n_jobs) - - def compute_con(self, indices, ranks, n_epochs=1): - """Compute multivariate imag. part of coherency between signals.""" - assert self.name in ['MIC', 'MIM'], ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') - - csd = self.reshape_csd() / n_epochs - n_times = csd.shape[0] - times = np.arange(n_times) - freqs = np.arange(self.n_freqs) - - if self.name == 'MIC': - self.patterns = np.full( - (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), - np.nan) - - con_i = 0 - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], ranks[0], ranks[1]): - self._log_connection_number(con_i) - - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] - con_idcs = [*seed_idcs, *target_idcs] - - C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - - # Eqs. 32 & 33 - C_bar, U_bar_aa, U_bar_bb = self._csd_svd( - C, seed_idcs, seed_rank, target_rank) - - # Eqs. 3 & 4 - E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) - - if self.name == 'MIC': - self._compute_mic(E, C, seed_idcs, target_idcs, n_times, - U_bar_aa, U_bar_bb, con_i) - else: - self._compute_mim(E, seed_idcs, target_idcs, con_i) - - con_i += 1 - - self.reshape_results() - - def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): - """Dimensionality reduction of CSD with SVD.""" - n_times = csd.shape[0] - n_seeds = len(seed_idcs) - n_targets = csd.shape[3] - n_seeds - - C_aa = csd[..., :n_seeds, :n_seeds] - C_ab = csd[..., :n_seeds, n_seeds:] - C_bb = csd[..., n_seeds:, n_seeds:] - C_ba = csd[..., n_seeds:, :n_seeds] - - # Eq. 32 - if seed_rank != n_seeds: - U_aa = np.linalg.svd(np.real(C_aa), full_matrices=False)[0] - U_bar_aa = U_aa[..., :seed_rank] - else: - U_bar_aa = np.broadcast_to( - np.identity(n_seeds), - (n_times, self.n_freqs) + (n_seeds, n_seeds)) - - if target_rank != n_targets: - U_bb = np.linalg.svd(np.real(C_bb), full_matrices=False)[0] - U_bar_bb = U_bb[..., :target_rank] - else: - U_bar_bb = np.broadcast_to( - np.identity(n_targets), - (n_times, self.n_freqs) + (n_targets, n_targets)) - - # Eq. 33 - C_bar_aa = np.matmul( - U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul( - U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul( - U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul( - U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa)) - C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), - np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) - - return C_bar, U_bar_aa, U_bar_bb - - def _compute_e(self, csd, n_seeds): - """Compute E from the CSD.""" - C_r = np.real(csd) - - parallel, parallel_compute_t, _ = parallel_func( - _mic_mim_compute_t, self.n_jobs, verbose=False) - - # imag. part of T filled when data is rank-deficient - T = np.zeros(csd.shape, dtype=np.complex128) - for block_i in ProgressBar( - range(self.n_steps), mesg="frequency blocks"): - freqs = self._get_block_indices(block_i, self.n_freqs) - T[:, freqs] = np.array(parallel(parallel_compute_t( - C_r[:, f], T[:, f], n_seeds) for f in freqs) - ).transpose(1, 0, 2, 3) - - if not np.isreal(T).all() or not np.isfinite(T).all(): - raise RuntimeError( - 'the transformation matrix of the data must be real-valued ' - 'and contain no NaN or infinity values; check that you are ' - 'using full rank data or specify an appropriate rank for the ' - 'seeds and targets that is less than or equal to their ranks') - T = np.real(T) # make T real if check passes - - # Eq. 4 - D = np.matmul(T, np.matmul(csd, T)) - - # E as imag. part of D between seeds and targets - return np.imag(D[..., :n_seeds, n_seeds:]) - - def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, - U_bar_bb, con_i): - """Compute MIC and the associated spatial patterns.""" - n_seeds = len(seed_idcs) - n_targets = len(target_idcs) - times = np.arange(n_times) - freqs = np.arange(self.n_freqs) - - # Eigendecomp. to find spatial filters for seeds and targets - w_seeds, V_seeds = np.linalg.eigh( - np.matmul(E, E.transpose(0, 1, 3, 2))) - w_targets, V_targets = np.linalg.eigh( - np.matmul(E.transpose(0, 1, 3, 2), E)) - if ( - len(seed_idcs) == len(target_idcs) and - np.all(np.sort(seed_idcs) == np.sort(target_idcs)) - ): - # strange edge-case where the eigenvectors returned should be a set - # of identity matrices with one rotated by 90 degrees, but are - # instead identical (i.e. are not rotated versions of one another). - # This leads to the case where the spatial filters are incorrectly - # applied, resulting in connectivity estimates of ~0 when they - # should be perfectly correlated ~1. Accordingly, we manually - # create a set of rotated identity matrices to use as the filters. - create_filter = False - stop = False - while not create_filter and not stop: - for time_i in range(n_times): - for freq_i in range(self.n_freqs): - if np.all(V_seeds[time_i, freq_i] == - V_targets[time_i, freq_i]): - create_filter = True - break - stop = True - if create_filter: - n_chans = E.shape[2] - eye_4d = np.zeros_like(V_seeds) - eye_4d[:, :, np.arange(n_chans), np.arange(n_chans)] = 1 - V_seeds = eye_4d - V_targets = np.rot90(eye_4d, axes=(2, 3)) - - # Spatial filters with largest eigval. for seeds and targets - alpha = V_seeds[times[:, None], freqs, :, w_seeds.argmax(axis=2)] - beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] - - # Eq. 46 (seed spatial patterns) - self.patterns[0, con_i, :n_seeds] = (np.matmul( - np.real(C[..., :n_seeds, :n_seeds]), - np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T - - # Eq. 47 (target spatial patterns) - self.patterns[1, con_i, :n_targets] = (np.matmul( - np.real(C[..., n_seeds:, n_seeds:]), - np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T - - # Eq. 7 - self.con_scores[con_i] = (np.einsum( - 'ijk,ijk->ij', alpha, np.matmul(E, np.expand_dims( - beta, axis=3))[..., 0] - ) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2)).T - - def _compute_mim(self, E, seed_idcs, target_idcs, con_i): - """Compute MIM (a.k.a. GIM if seeds == targets).""" - # Eq. 14 - self.con_scores[con_i] = np.matmul( - E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T - - # Eq. 15 - if ( - len(seed_idcs) == len(target_idcs) and - np.all(np.sort(seed_idcs) == np.sort(target_idcs)) - ): - self.con_scores[con_i] *= 0.5 - - def reshape_results(self): - """Remove time dimension from results, if necessary.""" - if self.n_times == 0: - self.con_scores = self.con_scores[..., 0] - if self.patterns is not None: - self.patterns = self.patterns[..., 0] - - -def _mic_mim_compute_t(C, T, n_seeds): - """Compute T for a single frequency (used for MIC and MIM).""" - for time_i in range(C.shape[0]): - T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( - C[time_i, :n_seeds, :n_seeds], -0.5 - ) - T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power( - C[time_i, n_seeds:, n_seeds:], -0.5 - ) - - return T - - -class _MICEst(_MultivariateCohEstBase): - """Multivariate imaginary part of coherency (MIC) estimator.""" - - name = "MIC" - - -class _MIMEst(_MultivariateCohEstBase): - """Multivariate interaction measure (MIM) estimator.""" - - name = "MIM" - - -class _GCEstBase(_EpochMeanMultivariateConEstBase): - """Base multivariate state-space Granger causality estimator.""" - - accumulate_psd = False - - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_lags, n_jobs=1): - super(_GCEstBase, self).__init__( - n_signals, n_cons, n_freqs, n_times, n_jobs) - - self.freq_res = (self.n_freqs - 1) * 2 - if n_lags >= self.freq_res: - raise ValueError( - 'the number of lags (%i) must be less than double the ' - 'frequency resolution (%i)' % (n_lags, self.freq_res, )) - self.n_lags = n_lags - - def compute_con(self, indices, ranks, n_epochs=1): - """Compute multivariate state-space Granger causality.""" - assert self.name in ['GC', 'GC time-reversed'], ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') - - csd = self.reshape_csd() / n_epochs - - n_times = csd.shape[0] - times = np.arange(n_times) - freqs = np.arange(self.n_freqs) - - con_i = 0 - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], ranks[0], ranks[1]): - self._log_connection_number(con_i) - - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] - con_idcs = [*seed_idcs, *target_idcs] - - C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - - C_bar = self._csd_svd(C, seed_idcs, seed_rank, target_rank) - n_signals = seed_rank + target_rank - con_seeds = np.arange(seed_rank) - con_targets = np.arange(target_rank) + seed_rank - - autocov = self._compute_autocov(C_bar) - if self.name == "GC time-reversed": - autocov = autocov.transpose(0, 1, 3, 2) - - A_f, V = self._autocov_to_full_var(autocov) - A_f_3d = np.reshape( - A_f, (n_times, n_signals, n_signals * self.n_lags), order="F") - A, K = self._full_var_to_iss(A_f_3d) - - self.con_scores[con_i] = self._iss_to_ugc( - A, A_f_3d, K, V, con_seeds, con_targets) - - con_i += 1 - - self.reshape_results() - - def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): - """Dimensionality reduction of CSD with SVD on the covariance.""" - # sum over times and epochs to get cov. from CSD - cov = csd.sum(axis=(0, 1)) - - n_seeds = len(seed_idcs) - n_targets = csd.shape[3] - n_seeds - - cov_aa = cov[:n_seeds, :n_seeds] - cov_bb = cov[n_seeds:, n_seeds:] - - if seed_rank != n_seeds: - U_aa = np.linalg.svd(np.real(cov_aa), full_matrices=False)[0] - U_bar_aa = U_aa[:, :seed_rank] - else: - U_bar_aa = np.identity(n_seeds) - - if target_rank != n_targets: - U_bb = np.linalg.svd(np.real(cov_bb), full_matrices=False)[0] - U_bar_bb = U_bb[:, :target_rank] - else: - U_bar_bb = np.identity(n_targets) - - C_aa = csd[..., :n_seeds, :n_seeds] - C_ab = csd[..., :n_seeds, n_seeds:] - C_bb = csd[..., n_seeds:, n_seeds:] - C_ba = csd[..., n_seeds:, :n_seeds] - - C_bar_aa = np.matmul( - U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul( - U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul( - U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul( - U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa)) - C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), - np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) - - return C_bar - - def _compute_autocov(self, csd): - """Compute autocovariance from the CSD.""" - n_times = csd.shape[0] - n_signals = csd.shape[2] - - circular_shifted_csd = np.concatenate( - [np.flip(np.conj(csd[:, 1:]), axis=1), csd[:, :-1]], axis=1) - ifft_shifted_csd = self._block_ifft( - circular_shifted_csd, self.freq_res) - lags_ifft_shifted_csd = np.reshape( - ifft_shifted_csd[:, :self.n_lags + 1], - (n_times, self.n_lags + 1, n_signals ** 2), order="F") - - signs = np.repeat([1], self.n_lags + 1).tolist() - signs[1::2] = [x * -1 for x in signs[1::2]] - sign_matrix = np.repeat( - np.tile(np.array(signs), (n_signals ** 2, 1))[np.newaxis], - n_times, axis=0).transpose(0, 2, 1) - - return np.real(np.reshape( - sign_matrix * lags_ifft_shifted_csd, - (n_times, self.n_lags + 1, n_signals, n_signals), order="F")) - - def _block_ifft(self, csd, n_points): - """Compute block iFFT with n points.""" - shape = csd.shape - csd_3d = np.reshape( - csd, (shape[0], shape[1], shape[2] * shape[3]), order="F") - - csd_ifft = np.fft.ifft(csd_3d, n=n_points, axis=1) - - return np.reshape(csd_ifft, shape, order="F") - - def _autocov_to_full_var(self, autocov): - """Compute full VAR model using Whittle's LWR recursion.""" - if np.any(np.linalg.det(autocov) == 0): - raise RuntimeError( - 'the autocovariance matrix is singular; check if your data is ' - 'rank deficient and specify an appropriate rank argument <= ' - 'the rank of the seeds and targets') - - A_f, V = self._whittle_lwr_recursion(autocov) - - if not np.isfinite(A_f).all(): - raise RuntimeError('at least one VAR model coefficient is ' - 'infinite or NaN; check the data you are using') - - try: - np.linalg.cholesky(V) - except np.linalg.LinAlgError as np_error: - raise RuntimeError( - 'the covariance matrix of the residuals is not ' - 'positive-definite; check the singular values of your data ' - 'and specify an appropriate rank argument <= the rank of the ' - 'seeds and targets') from np_error - - return A_f, V - - def _whittle_lwr_recursion(self, G): - """Solve Yule-Walker eqs. for full VAR params. with LWR recursion. - - See: Whittle P., 1963. Biometrika, DOI: 10.1093/biomet/50.1-2.129 - """ - # Initialise recursion - n = G.shape[2] # number of signals - q = G.shape[1] - 1 # number of lags - t = G.shape[0] # number of times - qn = n * q - - cov = G[:, 0, :, :] # covariance - G_f = np.reshape( - G[:, 1:, :, :].transpose(0, 3, 1, 2), (t, qn, n), - order="F") # forward autocov - G_b = np.reshape( - np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), - order="F").transpose(0, 2, 1) # backward autocov - - A_f = np.zeros((t, n, qn)) # forward coefficients - A_b = np.zeros((t, n, qn)) # backward coefficients - - k = 1 # model order - r = q - k - k_f = np.arange(k * n) # forward indices - k_b = np.arange(r * n, qn) # backward indices - - try: - A_f[:, :, k_f] = np.linalg.solve( - cov, G_b[:, k_b, :].transpose(0, 2, 1)).transpose(0, 2, 1) - A_b[:, :, k_b] = np.linalg.solve( - cov, G_f[:, k_f, :].transpose(0, 2, 1)).transpose(0, 2, 1) - - # Perform recursion - for k in np.arange(2, q + 1): - var_A = (G_b[:, (r - 1) * n: r * n, :] - - np.matmul(A_f[:, :, k_f], G_b[:, k_b, :])) - var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :]) - AA_f = np.linalg.solve( - var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) - - var_A = (G_f[:, (k - 1) * n: k * n, :] - - np.matmul(A_b[:, :, k_b], G_f[:, k_f, :])) - var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :]) - AA_b = np.linalg.solve( - var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) - - A_f_previous = A_f[:, :, k_f] - A_b_previous = A_b[:, :, k_b] - - r = q - k - k_f = np.arange(k * n) - k_b = np.arange(r * n, qn) - - A_f[:, :, k_f] = np.dstack( - (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f)) - A_b[:, :, k_b] = np.dstack( - (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous))) - except np.linalg.LinAlgError as np_error: - raise RuntimeError( - 'the autocovariance matrix is singular; check if your data is ' - 'rank deficient and specify an appropriate rank argument <= ' - 'the rank of the seeds and targets') from np_error - - V = cov - np.matmul(A_f, G_f) - A_f = np.reshape(A_f, (t, n, n, q), order="F") - - return A_f, V - - def _full_var_to_iss(self, A_f): - """Compute innovations-form parameters for a state-space model. - - Parameters computed from a full VAR model using Aoki's method. For a - non-moving-average full VAR model, the state-space parameter C - (observation matrix) is identical to AF of the VAR model. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - t = A_f.shape[0] - m = A_f.shape[1] # number of signals - p = A_f.shape[2] // m # number of autoregressive lags - - I_p = np.dstack(t * [np.eye(m * p)]).transpose(2, 0, 1) - A = np.hstack((A_f, I_p[:, : (m * p - m), :])) # state transition - # matrix - K = np.hstack(( - np.dstack(t * [np.eye(m)]).transpose(2, 0, 1), - np.zeros((t, (m * (p - 1)), m)))) # Kalman gain matrix - - return A, K - - def _iss_to_ugc(self, A, C, K, V, seeds, targets): - """Compute unconditional GC from innovations-form state-space params. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - times = np.arange(A.shape[0]) - freqs = np.arange(self.n_freqs) - z = np.exp(-1j * np.pi * np.linspace(0, 1, self.n_freqs)) # points - # on a unit circle in the complex plane, one for each frequency - - H = self._iss_to_tf(A, C, K, z) # spectral transfer function - V_22_1 = np.linalg.cholesky(self._partial_covar(V, seeds, targets)) - HV = np.matmul(H, np.linalg.cholesky(V)) - S = np.matmul(HV, HV.conj().transpose(0, 1, 3, 2)) # Eq. 6 - S_11 = S[np.ix_(freqs, times, targets, targets)] - HV_12 = np.matmul(H[np.ix_(freqs, times, targets, seeds)], V_22_1) - HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2)) - - # Eq. 11 - return np.real( - np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH))) - - def _iss_to_tf(self, A, C, K, z): - """Compute transfer function for innovations-form state-space params. - - In the frequency domain, the back-shift operator, z, is a vector of - points on a unit circle in the complex plane. z = e^-iw, where -pi < w - <= pi. - - A note on efficiency: solving over the 4D time-freq. tensor is slower - than looping over times and freqs when n_times and n_freqs high, and - when n_times and n_freqs low, looping over times and freqs very fast - anyway (plus tensor solving doesn't allow for parallelisation). - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - t = A.shape[0] - h = self.n_freqs - n = C.shape[1] - m = A.shape[1] - I_n = np.eye(n) - I_m = np.eye(m) - H = np.zeros((h, t, n, n), dtype=np.complex128) - - parallel, parallel_compute_H, _ = parallel_func( - _gc_compute_H, self.n_jobs, verbose=False - ) - H = np.zeros((h, t, n, n), dtype=np.complex128) - for block_i in ProgressBar( - range(self.n_steps), mesg="frequency blocks" - ): - freqs = self._get_block_indices(block_i, self.n_freqs) - H[freqs] = parallel( - parallel_compute_H(A, C, K, z[k], I_n, I_m) for k in freqs) - - return H - - def _partial_covar(self, V, seeds, targets): - """Compute partial covariance of a matrix. - - Given a covariance matrix V, the partial covariance matrix of V between - indices i and j, given k (V_ij|k), is equivalent to V_ij - V_ik * - V_kk^-1 * V_kj. In this case, i and j are seeds, and k are targets. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - times = np.arange(V.shape[0]) - W = np.linalg.solve( - np.linalg.cholesky(V[np.ix_(times, targets, targets)]), - V[np.ix_(times, targets, seeds)], - ) - W = np.matmul(W.transpose(0, 2, 1), W) - - return V[np.ix_(times, seeds, seeds)] - W - - def reshape_results(self): - """Remove time dimension from con. scores, if necessary.""" - if self.n_times == 0: - self.con_scores = self.con_scores[:, :, 0] - - -def _gc_compute_H(A, C, K, z_k, I_n, I_m): - """Compute transfer function for innovations-form state-space params. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101, Eq. 4. - """ - from scipy import linalg # XXX: is this necessary??? - H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) - for t in range(A.shape[0]): - H[t] = I_n + np.matmul( - C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t])) - - return H - - -class _GCEst(_GCEstBase): - """[seeds -> targets] state-space GC estimator.""" - - name = "GC" - - -class _GCTREst(_GCEstBase): - """time-reversed[seeds -> targets] state-space GC estimator.""" - - name = "GC time-reversed" - -############################################################################### - - -# map names to estimator types -_CON_METHOD_MAP = {'mic': _MICEst, 'mim': _MIMEst, 'gc': _GCEst, - 'gc_tr': _GCTREst} - -_gc_methods = ['gc', 'gc_tr'] - - -@ verbose -@ fill_doc -def spectral_connectivity_epochs_multivariate( - data, names=None, method='mic', indices=None, sfreq=None, - mode='multitaper', fmin=None, fmax=np.inf, fskip=0, faverage=False, - tmin=None, tmax=None, mt_bandwidth=None, mt_adaptive=False, - mt_low_bias=True, cwt_freqs=None, cwt_n_cycles=7, gc_n_lags=40, rank=None, - block_size=1000, n_jobs=1, verbose=None -): - r"""Compute multivariate (time-)frequency-domain connectivity measures. - - The connectivity method(s) are specified using the "method" parameter. - All methods are based on estimates of the cross-spectral density (CSD). - - Parameters - ---------- - data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs - The data from which to compute connectivity. Note that it is also - possible to combine multiple signals by providing a list of tuples, - e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], - corresponds to 3 epochs, and arr_* could be an array with the same - number of time points as stc_*. The array-like object can also - be a list/generator of array, shape =(n_signals, n_times), - or a list/generator of SourceEstimate or VolSourceEstimate objects. - %(names)s - method : str | list of str - Connectivity measure(s) to compute. These can be ``['mic', 'mim', 'gc', - 'gc_tr']``. - indices : tuple of array | None - Two arrays with indices of connections for which to compute - connectivity. Each array for the seeds and targets should consist of - nested arrays containing the channel indices for each multivariate - connection. If ``None``, connections between all channels are computed, - unless a Granger causality method is called, in which case an error is - raised. - sfreq : float - The sampling frequency. Required if data is not - :class:`Epochs `. - mode : str - Spectrum estimation mode can be either: 'multitaper', 'fourier', or - 'cwt_morlet'. - fmin : float | tuple of float - The lower frequency of interest. Multiple bands are defined using - a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. - fmax : float | tuple of float - The upper frequency of interest. Multiple bands are dedined using - a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. - fskip : int - Omit every "(fskip + 1)-th" frequency bin to decimate in frequency - domain. - faverage : bool - Average connectivity scores for each frequency band. If True, - the output freqs will be a list with arrays of the frequencies - that were averaged. - tmin : float | None - Time to start connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - tmax : float | None - Time to end connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - mt_bandwidth : float | None - The bandwidth of the multitaper windowing function in Hz. - Only used in 'multitaper' mode. - mt_adaptive : bool - Use adaptive weights to combine the tapered spectra into PSD. - Only used in 'multitaper' mode. - mt_low_bias : bool - Only use tapers with more than 90 percent spectral concentration - within bandwidth. Only used in 'multitaper' mode. - cwt_freqs : array - Array of frequencies of interest. Only used in 'cwt_morlet' mode. - cwt_n_cycles : float | array of float - Number of cycles. Fixed number or one per frequency. Only used in - 'cwt_morlet' mode. - gc_n_lags : int - Number of lags to use when computing Granger causality (the vector - autoregressive model order). Higher values increase computational cost, - but reduce the degree of spectral smoothing in the results. Must be < - (n_freqs - 1) * 2. Only used if ``method`` contains any of ``['gc', - 'gc_tr']``. - rank : tuple of array | None - Two arrays with the rank to project the seed and target data to, - respectively, using singular value decomposition. If None, the rank of - the data is computed and projected to. Only used if ``method`` contains - any of ``['mic', 'mim', 'gc', 'gc_tr']``. - block_size : int - How many CSD entries to compute at once (higher numbers are faster but - require more memory). - n_jobs : int - How many samples to process in parallel. - %(verbose)s - - Returns - ------- - con : array | list of array - Computed connectivity measure(s). Either an instance of - ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of the connectivity result will be: - - - ``(n_cons, n_freqs)`` for multitaper or fourier modes - - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode - - ``n_cons = 1`` when ``indices=None`` - - ``n_cons = len(indices[0])`` when indices is supplied - - See Also - -------- - mne_connectivity.spectral_connectivity_epochs - mne_connectivity.spectral_connectivity_time - mne_connectivity.SpectralConnectivity - mne_connectivity.SpectroTemporalConnectivity - - Notes - ----- - Please note that the interpretation of the measures in this function - depends on the data and underlying assumptions and does not necessarily - reflect a causal relationship between brain regions. - - These measures are not to be interpreted over time. Each Epoch passed into - the dataset is interpreted as an independent sample of the same - connectivity structure. Within each Epoch, it is assumed that the spectral - measure is stationary. The spectral measures implemented in this function - are computed across Epochs. **Thus, spectral measures computed with only - one Epoch will result in errorful values and spectral measures computed - with few Epochs will be unreliable.** Please see - ``spectral_connectivity_time`` for time-resolved connectivity estimation. - - The spectral densities can be estimated using a multitaper method with - digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier - transform with Hanning windows, or a continuous wavelet transform using - Morlet wavelets. The spectral estimation mode is specified using the - "mode" parameter. - - By default, "indices" is None, and the connectivity between all signals is - computed and a single connectivity spectrum will be returned (this is not - possible if a Granger causality method is called). If one is only - interested in the connectivity between some signals, the "indices" - parameter can be used. Seed and target indices for each connection should - be specified as nested array-likes. For example, to compute the - connectivity between signals (0, 1) -> (2, 3) and (0, 1) -> (4, 5), indices - should be specified as:: - - indices = ([[0, 1], [0, 1]], # seeds - [[2, 3], [4, 5]]) # targets - - More information on working with multivariate indices and handling - connections where the number of seeds and targets are not equal can be - found in the :doc:`../auto_examples/handling_ragged_arrays` example. - - **Supported Connectivity Measures** - - The connectivity method(s) is specified using the "method" parameter. - Multiple measures can be computed at once by using a list/tuple, e.g., - ``['mic', 'gc']``. The following methods are supported: - - 'mic' : Maximised Imaginary part of Coherency (MIC) - :footcite:`EwaldEtAl2012` given by: - - :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} - {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} - \parallel}}` - - where: :math:`\boldsymbol{E}` is the imaginary part of the - transformed cross-spectral density between seeds and targets; and - :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are - eigenvectors for the seeds and targets, such that - :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises - connectivity between the seeds and targets. - - 'mim' : Multivariate Interaction Measure (MIM) - :footcite:`EwaldEtAl2012` given by: - - :math:`MIM=tr(\boldsymbol{EE}^T)` - - 'gc' : State-space Granger Causality (GC) :footcite:`BarnettSeth2015` - given by: - - :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert - \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss - \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, - - where: :math:`s` and :math:`t` represent the seeds and targets, - respectively; :math:`\boldsymbol{H}` is the spectral transfer - function; :math:`\boldsymbol{\Sigma}` is the residuals matrix of - the autoregressive model; and :math:`\boldsymbol{S}` is - :math:`\boldsymbol{\Sigma}` transformed by :math:`\boldsymbol{H}`. - - 'gc_tr' : State-space GC on time-reversed signals - :footcite:`BarnettSeth2015,WinklerEtAl2016` given by the same equation - as for 'gc', but where the autocovariance sequence from which the - autoregressive model is produced is transposed to mimic the reversal of - the original signal in time. - - References - ---------- - .. footbibliography:: - """ - ( - fmin, fmax, n_bands, method, con_method_types, accumulate_psd, - parallel, my_epoch_spectral_connectivity - ) = _check_spectral_connectivity_epochs_settings( - method, fmin, fmax, n_jobs, verbose, _CON_METHOD_MAP) - - if n_bands != 1 and any( - this_method in _gc_methods for this_method in method - ): - raise ValueError('computing Granger causality on multiple frequency ' - 'bands is not yet supported') - - (names, times_in, sfreq, events, event_id, - metadata) = _check_spectral_connectivity_epochs_data(data, sfreq, names) - - # loop over data; it could be a generator that returns - # (n_signals x n_times) arrays or SourceEstimates - epoch_idx = 0 - logger.info('Connectivity computation...') - warn_times = True - for epoch_block in _get_n_epochs(data, n_jobs): - if epoch_idx == 0: - # initialize everything times and frequencies - (times, n_times, times_in, n_times_in, tmin_idx, tmax_idx, n_freqs, - freq_mask, freqs, freqs_bands, freq_idx_bands, n_signals, - warn_times) = _prepare_connectivity( - epoch_block=epoch_block, times_in=times_in, tmin=tmin, - tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, mode=mode, - fskip=fskip, n_bands=n_bands, cwt_freqs=cwt_freqs, - faverage=faverage) - - # check indices input - n_cons, indices_use = _check_indices(indices, method, n_signals) - - # check rank input and compute data ranks - rank = _check_rank_input(rank, data, indices_use) - - # make sure padded indices are stored in the connectivity object - if indices is not None: - indices = tuple(np.array(indices_use)) # create a copy - - # get the window function, wavelets, etc for different modes - (spectral_params, mt_adaptive, n_times_spectrum, - n_tapers) = _assemble_spectral_params( - mode=mode, n_times=n_times, mt_adaptive=mt_adaptive, - mt_bandwidth=mt_bandwidth, sfreq=sfreq, - mt_low_bias=mt_low_bias, cwt_n_cycles=cwt_n_cycles, - cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) - - # unique signals for which we actually need to compute CSD - sig_idx = np.unique(np.concatenate(np.concatenate( - indices_use))) - sig_idx = sig_idx[sig_idx != -1] - remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} - remapping[-1] = -1 - remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) - con_i = 0 - for seed, target in zip(indices_use[0], indices_use[1]): - remapped_inds[0][con_i] = np.array([ - remapping[idx] for idx in seed]) - remapped_inds[1][con_i] = np.array([ - remapping[idx] for idx in target]) - con_i += 1 - remapped_sig = [remapping[idx] for idx in sig_idx] - n_signals_use = len(sig_idx) - - # map indices to unique indices - indices_use = remapped_inds # use remapped seeds & targets - idx_map = [np.sort(np.repeat(remapped_sig, len(sig_idx))), - np.tile(remapped_sig, len(sig_idx))] - - # create instances of the connectivity estimators - con_methods = [] - for mtype_i, mtype in enumerate(con_method_types): - method_params = dict(n_cons=n_cons, n_freqs=n_freqs, - n_times=n_times_spectrum, - n_signals=n_signals_use) - if method[mtype_i] in _gc_methods: - method_params.update(dict(n_lags=gc_n_lags)) - con_methods.append(mtype(**method_params)) - - sep = ', ' - metrics_str = sep.join([meth.name for meth in con_methods]) - logger.info(' the following metrics will be computed: %s' - % metrics_str) - - call_params = dict( - sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, - method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, - n_cons=n_cons, block_size=block_size, - psd=None, accumulate_psd=accumulate_psd, - mt_adaptive=mt_adaptive, - con_method_types=con_method_types, - con_methods=con_methods if n_jobs == 1 else None, - n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, - gc_n_lags=gc_n_lags, multivariate_con=True, - accumulate_inplace=True if n_jobs == 1 else False) - call_params.update(**spectral_params) - - epoch_idx = _compute_spectral_methods_epochs( - con_methods, epoch_block, epoch_idx, call_params, parallel, - my_epoch_spectral_connectivity, n_jobs, n_times_in, times_in, - warn_times) - n_epochs = epoch_idx - - # compute final connectivity scores - con = list() - patterns = list() - for conn_method in con_methods: - - # compute connectivity scores - conn_method.compute_con(indices_use, rank, n_epochs) - - # get the connectivity scores - this_con = conn_method.con_scores - this_patterns = conn_method.patterns - - if this_con.shape[0] != n_cons: - raise RuntimeError( - 'first dimension of connectivity scores does not match the ' - 'number of connections; please contact the mne-connectivity ' - 'developers') - if faverage: - if this_con.shape[1] != n_freqs: - raise RuntimeError( - 'second dimension of connectivity scores does not match ' - 'the number of frequencies; please contact the ' - 'mne-connectivity developers') - con_shape = (n_cons, n_bands) + this_con.shape[2:] - this_con_bands = np.empty(con_shape, dtype=this_con.dtype) - for band_idx in range(n_bands): - this_con_bands[:, band_idx] = np.mean( - this_con[:, freq_idx_bands[band_idx]], axis=1) - this_con = this_con_bands - - if this_patterns is not None: - patterns_shape = list(this_patterns.shape) - patterns_shape[3] = n_bands - this_patterns_bands = np.empty(patterns_shape, - dtype=this_patterns.dtype) - for band_idx in range(n_bands): - this_patterns_bands[:, :, :, band_idx] = np.mean( - this_patterns[:, :, :, freq_idx_bands[band_idx]], - axis=3) - this_patterns = this_patterns_bands - - con.append(this_con) - patterns.append(this_patterns) - - conn_list = _store_results( - con=con, patterns=patterns, method=method, freqs=freqs, - faverage=faverage, freqs_bands=freqs_bands, names=names, mode=mode, - indices=indices, n_epochs=n_epochs, times=times, n_tapers=n_tapers, - metadata=metadata, events=events, event_id=event_id, rank=rank, - gc_n_lags=gc_n_lags, n_signals=n_signals) - - return conn_list diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 54bfafa5..592291f0 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -9,11 +9,10 @@ from mne_connectivity import ( SpectralConnectivity, spectral_connectivity_epochs, - spectral_connectivity_epochs_multivariate, read_connectivity, spectral_connectivity_time) +from mne_connectivity.spectral.epochs import _CohEst, _get_n_epochs from mne_connectivity.spectral.epochs import ( - _get_n_epochs, _compute_freq_mask, _compute_freqs) -from mne_connectivity.spectral.epochs_bivariate import _CohEst + _compute_freq_mask, _compute_freqs) def create_test_dataset(sfreq, n_signals, n_epochs, n_times, tmin, tmax, @@ -448,7 +447,7 @@ def test_spectral_connectivity_epochs_multivariate(method): data = data.reshape(n_signals, n_epochs, n_times) data = np.transpose(data, [1, 0, 2]) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=20) freqs = con.freqs @@ -474,17 +473,17 @@ def test_spectral_connectivity_epochs_multivariate(method): # check that target -> seed connectivity is low indices_ts = (indices[1], indices[0]) - con_ts = spectral_connectivity_epochs_multivariate( + con_ts = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices_ts, sfreq=sfreq, gc_n_lags=20) assert con_ts.get_data()[0, gidx[0]:gidx[1]].mean() < lower_t # check that TRGC is positive (i.e. net seed -> target connectivity not # due to noise) - con_tr = spectral_connectivity_epochs_multivariate( + con_tr = spectral_connectivity_epochs( data, method='gc_tr', mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=20) - con_ts_tr = spectral_connectivity_epochs_multivariate( + con_ts_tr = spectral_connectivity_epochs( data, method='gc_tr', mode=mode, indices=indices_ts, sfreq=sfreq, gc_n_lags=20) trgc = ((con.get_data() - con_ts.get_data()) - @@ -498,7 +497,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check all-to-all conn. computed for MIC/MIM when no indices given if method in ['mic', 'mim']: - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=None, sfreq=sfreq) assert con.indices is None assert con.n_nodes == n_signals @@ -507,7 +506,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check ragged indices padded correctly ragged_indices = (np.array([[0]]), np.array([[1, 2]])) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq) assert np.all(np.array(con.indices) == np.array([np.array([[0, -1]]), np.array([[1, 2]])])) @@ -515,7 +514,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check shape of MIC patterns if method == 'mic': for mode in ['multitaper', 'cwt_morlet']: - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=10, fmax=25, cwt_freqs=np.arange(10, 25), faverage=True) @@ -536,7 +535,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check patterns averaged over freqs fmin = (5., 15.) fmax = (15., 30.) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=fmin, fmax=fmax, faverage=True) assert np.shape(con.attrs["patterns"][0][0])[1] == len(fmin) @@ -544,7 +543,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check patterns shape matches input data, not rank rank = (np.array([1]), np.array([1])) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=rank) assert (np.shape(con.attrs["patterns"][0][0])[0] == n_seeds) @@ -552,7 +551,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check patterns padded correctly ragged_indices = (np.array([[0]]), np.array([[1, 2]])) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq) patterns = np.array(con.attrs["patterns"]) @@ -587,7 +586,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): sfreq = 100 indices = (np.array([[0, 1]]), np.array([[2, 3]])) methods = ['mic', 'mim', 'gc', 'gc_tr'] - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=methods, indices=indices, mode='multitaper', sfreq=sfreq, fskip=0, faverage=False, tmin=0, tmax=None, mt_bandwidth=4, mt_low_bias=True, mt_adaptive=False, gc_n_lags=20, @@ -595,9 +594,8 @@ def test_multivariate_spectral_connectivity_epochs_regression(): # should take the absolute of the MIC scores, as the MATLAB implementation # returns the absolute values. - mne_results = {this_con.method: this_con.get_data() for this_con in con} - mne_results["mic"] = np.abs(mne_results["mic"]) - + mne_results = {this_con.method: np.abs(this_con.get_data()) + for this_con in con} matlab_results = pd.read_pickle( os.path.join(fpath, 'data', 'example_multivariate_matlab_results.pkl')) for method in methods: @@ -622,29 +620,40 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): with pytest.raises(TypeError, match='multivariate indices must contain array-likes'): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=non_nested_indices, - sfreq=sfreq, cwt_freqs=cwt_freqs, gc_n_lags=10) + sfreq=sfreq, gc_n_lags=10) # check bad indices with repeated channels caught with pytest.raises(ValueError, match='multivariate indices cannot contain repeated'): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=repeated_indices, - sfreq=sfreq, cwt_freqs=cwt_freqs, gc_n_lags=10) + sfreq=sfreq, gc_n_lags=10) + + # check mixed methods caught + with pytest.raises(ValueError, + match='bivariate and multivariate connectivity'): + if isinstance(method, str): + mixed_methods = [method, 'coh'] + elif isinstance(method, list): + mixed_methods = [*method, 'coh'] + spectral_connectivity_epochs(data, method=mixed_methods, mode=mode, + indices=indices, sfreq=sfreq, + cwt_freqs=cwt_freqs) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) with pytest.raises(ValueError, match='ranks for seeds and targets must be'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=too_low_rank, cwt_freqs=cwt_freqs) too_high_rank = (np.array([3]), np.array([3])) with pytest.raises(ValueError, match='ranks for seeds and targets must be'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=too_high_rank, cwt_freqs=cwt_freqs) @@ -655,7 +664,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): assert np.all(np.linalg.matrix_rank(bad_data[:, (0, 1), :]) == 1) assert np.all(np.linalg.matrix_rank(bad_data[:, (2, 3), :]) == 1) if isinstance(method, str): - rank_con = spectral_connectivity_epochs_multivariate( + rank_con = spectral_connectivity_epochs( bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=10, cwt_freqs=cwt_freqs) assert rank_con.attrs["rank"] == (np.array([1]), np.array([1])) @@ -664,7 +673,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): # check rank-deficient transformation matrix caught with pytest.raises(RuntimeError, match='the transformation matrix'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=(np.array([2]), np.array([2])), cwt_freqs=cwt_freqs) @@ -675,36 +684,37 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): frange = (5, 10) n_lags = 200 # will be far too high with pytest.raises(ValueError, match='the number of lags'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=frange[0], fmax=frange[1], gc_n_lags=n_lags, cwt_freqs=cwt_freqs) # check no indices caught with pytest.raises(ValueError, match='indices must be specified'): - spectral_connectivity_epochs_multivariate( - data, method=method, mode=mode, indices=None, sfreq=sfreq, - cwt_freqs=cwt_freqs) + spectral_connectivity_epochs(data, method=method, mode=mode, + indices=None, sfreq=sfreq, + cwt_freqs=cwt_freqs) # check intersecting indices caught bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) with pytest.raises(ValueError, match='seed and target indices must not intersect'): - spectral_connectivity_epochs_multivariate( - data, method=method, mode=mode, indices=bad_indices, - sfreq=sfreq, cwt_freqs=cwt_freqs) + spectral_connectivity_epochs(data, method=method, mode=mode, + indices=bad_indices, sfreq=sfreq, + cwt_freqs=cwt_freqs) # check bad fmin/fmax caught with pytest.raises(ValueError, match='computing Granger causality on multiple'): - spectral_connectivity_epochs_multivariate( - data, method=method, mode=mode, indices=indices, sfreq=sfreq, - fmin=(10., 15.), fmax=(15., 20.), cwt_freqs=cwt_freqs) + spectral_connectivity_epochs(data, method=method, mode=mode, + indices=indices, sfreq=sfreq, + fmin=(10., 15.), fmax=(15., 20.), + cwt_freqs=cwt_freqs) # check rank-deficient autocovariance caught with pytest.raises(RuntimeError, match='the autocovariance matrix is singular'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=(np.array([2]), np.array([2])), cwt_freqs=cwt_freqs) @@ -721,7 +731,7 @@ def test_multivar_spectral_connectivity_parallel(method): data = rng.randn(n_epochs, n_signals, n_times) indices = (np.array([[0, 1]]), np.array([[2, 3]])) - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode="multitaper", indices=indices, sfreq=sfreq, gc_n_lags=10, n_jobs=2) spectral_connectivity_time( @@ -751,13 +761,12 @@ def test_multivar_spectral_connectivity_flipped_indices(): # we test on GC since this is a directed connectivity measure method = 'gc' - con_st = spectral_connectivity_epochs_multivariate( # seed -> target + con_st = spectral_connectivity_epochs( # seed -> target data, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10) - con_ts = spectral_connectivity_epochs_multivariate( # target -> seed + con_ts = spectral_connectivity_epochs( # target -> seed data, method=method, indices=flipped_indices, sfreq=sfreq, gc_n_lags=10) - con_st_ts = spectral_connectivity_epochs_multivariate( - # seed -> target; target -> seed + con_st_ts = spectral_connectivity_epochs( # seed -> target; target -> seed data, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10) assert not np.all(con_st.get_data() == con_ts.get_data()) assert np.all(con_st.get_data()[0] == con_st_ts.get_data()[0]) @@ -1289,7 +1298,7 @@ def test_multivar_save_load(tmp_path): non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]])) ragged_indices = (np.array([[0, 1]]), np.array([[2]])) for indices in [non_ragged_indices, ragged_indices]: - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( epochs, method=['mic', 'mim', 'gc', 'gc_tr'], indices=indices, sfreq=sfreq, fmin=10, fmax=30) for this_con in con: @@ -1306,9 +1315,12 @@ def test_multivar_save_load(tmp_path): assert a == b -@pytest.mark.parametrize("method", ['coh', 'plv', 'pli', 'wpli', 'ciplv']) +@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", + "mic", "mim"]) @pytest.mark.parametrize("indices", [None, - (np.array([0, 1]), np.array([2, 3]))]) + (np.array([0, 1]), np.array([2, 3])), + (np.array([[0, 1]]), np.array([[2, 3]])) + ]) def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. @@ -1325,6 +1337,14 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") + # mutlivariate and bivariate methods require the right indices shape + if method in ["mic", "mim"]: + if indices is not None and indices[0].ndim == 1: + pytest.skip() + else: + if indices is not None and indices[0].ndim == 2: + pytest.skip() + # test the pair of method and indices defined to check the output indices con_epochs = spectral_connectivity_epochs( epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 @@ -1346,53 +1366,3 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): assert np.all(np.array(con.indices) == np.array(read_con.indices)) else: assert con.indices is None and read_con.indices is None - - -@pytest.mark.parametrize("method", ['mic', 'mim', 'gc', 'gc_tr']) -@pytest.mark.parametrize("indices", [None, - (np.array([[0, 1]]), np.array([[2, 3]]))]) -def test_multivar_spectral_connectivity_indices_roundtrip_io( - tmp_path, method, indices -): - """Test that indices values and type is maintained after saving. - - If `indices` is None, `indices` in the returned connectivity object should - be None, otherwise, `indices` should be a tuple. The type of `indices` and - its values should be retained after saving and reloading. - """ - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 - data = rng.randn(n_epochs, n_chs, n_times) - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) - freqs = np.arange(10, 31) - tmp_file = os.path.join(tmp_path, "foo_mvc.nc") - - # test the pair of method and indices defined to check the output indices - if indices is None and method in ['gc', 'gc_tr']: - # indicesmust be specified for GC - pytest.skip() - - con_epochs = spectral_connectivity_epochs_multivariate( - epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30, - gc_n_lags=10 - ) - con_time = spectral_connectivity_time( - epochs, freqs, method=method, indices=indices, sfreq=sfreq, - gc_n_lags=10 - ) - - for con in [con_epochs, con_time]: - con.save(tmp_file) - read_con = read_connectivity(tmp_file) - - if indices is not None: - # check indices of same type (tuples) - assert isinstance(con.indices, tuple) and isinstance( - read_con.indices, tuple - ) - # check indices have same values - assert np.all(np.array(con.indices) == np.array(read_con.indices)) - else: - assert con.indices is None and read_con.indices is None diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index d0059ace..3798f699 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -13,9 +13,8 @@ from mne.utils import (logger, verbose) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) -from .epochs import _compute_freq_mask -from .epochs_multivariate import (_MICEst, _MIMEst, _GCEst, _GCTREst, - _check_rank_input) +from .epochs import (_MICEst, _MIMEst, _GCEst, _GCTREst, _compute_freq_mask, + _check_rank_input) from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, check_multivariate_indices, fill_doc From 30f6c05119385ecbb9f7c3cb79397553da51003a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 13:02:00 +0100 Subject: [PATCH 30/40] switched to masked indices for multivariate conn --- doc/api.rst | 1 - examples/handling_ragged_arrays.py | 41 +++--- mne_connectivity/__init__.py | 4 +- mne_connectivity/io.py | 6 +- mne_connectivity/spectral/epochs.py | 58 ++++---- .../spectral/tests/test_spectral.py | 69 +++++++-- mne_connectivity/spectral/time.py | 45 +++--- mne_connectivity/tests/test_utils.py | 65 +++++++-- mne_connectivity/utils/__init__.py | 2 +- mne_connectivity/utils/utils.py | 132 ++++++++++-------- 10 files changed, 256 insertions(+), 167 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index c91f9c02..300601b4 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -75,7 +75,6 @@ Post-processing on connectivity seed_target_indices seed_target_multivariate_indices check_indices - check_multivariate_indices select_order Visualization functions diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index 67f06a4e..8076b9e2 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -62,8 +62,8 @@ # ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'), # np.array([[2, 3, 4], [4 ]], dtype='object')) # -# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be -# specified.** +# **N.B. Note that since NumPy v1.19.0, dtype='object' must be specified when +# forming ragged arrays.** # # Just as for bivariate connectivity, the length of ``indices[0]`` and # ``indices[1]`` is equal (i.e. the number of connections), however information @@ -71,19 +71,16 @@ # array. Importantly, these indices are ragged, as the first connection will be # computed between 2 seed and 3 target channels, and the second connection # between 4 seed and 1 target channel. The connectivity functions will -# recognise the indices as being ragged, and pad them accordingly to make them -# easier to work with and compatible with the h5netcdf saving engine. The known -# value used to pad the arrays is ``-1``, an invalid channel index. The above -# indices would be padded to:: +# recognise the indices as being ragged, and pad them to a 'full' array by +# adding placeholder values which are masked accordingly. This makes the +# indices easier to work with, and also compatible with the engine used to save +# connectivity objects. For example, the above indices would become:: # -# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]), -# np.array([[2, 3, 4, -1], [4, -1, -1, -1]])) +# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]), +# np.array([[2, 3, 4, --], [4, --, --, --]])) # -# These indices are what is stored in the connectivity object, and is also the -# format of indices returned from the helper functions -# :func:`~mne_connectivity.check_multivariate_indices` and -# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also -# possible to pass the padded indices to the connectivity functions directly. +# where ``--`` are masked entries. These indices are what is stored in the +# returned connectivity objects. # # For the connectivity results themselves, the methods available in # MNE-Connectivity combine information across the different channels into a @@ -116,13 +113,13 @@ n_freqs = con.get_data().shape[-1] n_cons = len(ragged_indices[0]) max_n_chans = max( - [len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])]) + len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])) -# show that the padded indices entries are all -1 -assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels -assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels -assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels -assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels +# show that the padded indices entries are masked +assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels +assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels +assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels +assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels # patterns have shape [seeds/targets x cons x max channels x freqs (x times)] assert patterns.shape == (2, n_cons, max_n_chans, n_freqs) @@ -137,11 +134,11 @@ seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] -# extract patterns for second connection using the padded indices (pad = -1) +# extract patterns for second connection using the padded, masked indices seed_patterns_con2 = ( - patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)]) + patterns[0, 1, :padded_indices[0][1].count()]) target_patterns_con2 = ( - patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)]) + patterns[1, 1, :padded_indices[1][1].count()]) # show that shapes of patterns are correct assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index c2f03a6c..43fe0793 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -17,5 +17,5 @@ from .io import read_connectivity from .spectral import spectral_connectivity_time, spectral_connectivity_epochs from .vector_ar import vector_auto_regression, select_order -from .utils import (check_indices, check_multivariate_indices, degree, - seed_target_indices, seed_target_multivariate_indices) +from .utils import (check_indices, degree, seed_target_indices, + seed_target_multivariate_indices) diff --git a/mne_connectivity/io.py b/mne_connectivity/io.py index 63aa3501..e8d9b916 100644 --- a/mne_connectivity/io.py +++ b/mne_connectivity/io.py @@ -53,9 +53,11 @@ def _xarray_to_conn(array, cls_func): event_id = dict(zip(event_id_keys, event_id_vals)) array.attrs['event_id'] = event_id - # convert indices numpy arrays to a tuple of arrays + # convert indices numpy arrays to a tuple of masked arrays + # (only multivariate connectivity indices saved as arrays) if isinstance(array.attrs['indices'], np.ndarray): - array.attrs['indices'] = tuple(array.attrs['indices']) + array.attrs['indices'] = tuple( + np.ma.masked_values(array.attrs['indices'], -1)) # create the connectivity class conn = cls_func( diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 7ad551c9..914f8531 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -24,7 +24,7 @@ ProgressBar, _arange_div, _check_option, _time_mask, logger, warn, verbose) from ..base import (SpectralConnectivity, SpectroTemporalConnectivity) -from ..utils import fill_doc, check_indices, check_multivariate_indices +from ..utils import fill_doc, check_indices, _check_multivariate_indices def _compute_freqs(n_times, sfreq, cwt_freqs, mode): @@ -107,17 +107,24 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, logger.info('using all indices for multivariate connectivity') indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], np.arange(n_signals, dtype=int)[np.newaxis, :]) + indices_use = np.ma.masked_array(indices_use, + mask=False, fill_value=-1) else: logger.info('only using indices for lower-triangular matrix') # only compute r for lower-triangular region indices_use = np.tril_indices(n_signals, -1) else: if multivariate_con: - indices_use = check_multivariate_indices(indices) # pad with -1 + # mask indices + indices_use = _check_multivariate_indices(indices, n_signals) + indices_use = np.ma.concatenate([inds[np.newaxis] for inds in + indices_use]) + np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices[0], indices[1]): - intersection = np.intersect1d(seed, target) - if np.any(intersection != -1): # ignore padded entries + for seed, target in zip(indices_use[0], indices_use[1]): + intersection = np.intersect1d(seed.compressed(), + target.compressed()) + if intersection.size > 0: raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') @@ -191,7 +198,7 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, def _check_rank_input(rank, data, indices): """Check the rank argument is appropriate and compute rank if missing.""" - sv_tol = 1e-10 # tolerance for non-zero singular val (rel to largest) + sv_tol = 1e-10 # tolerance for non-zero singular val (rel. to largest) if rank is None: rank = np.zeros((2, len(indices[0])), dtype=int) @@ -200,16 +207,10 @@ def _check_rank_input(rank, data, indices): else: data_arr = data - # XXX: Unpadding of arrays after already padding them is perhaps not so - # efficient. However, we need to remove the padded values to - # ensure only the correct channels are indexed, and having two - # versions of indices is a bit messy currently. A candidate for - # refactoring to simplify code. - for group_i in range(2): # seeds and targets for con_i, con_idcs in enumerate(indices[group_i]): - con_idcs = con_idcs[con_idcs != -1] # -1 is padded value - s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) + s = np.linalg.svd(data_arr[:, con_idcs.compressed()], + compute_uv=False) rank[group_i][con_i] = np.min( [np.count_nonzero(epoch >= epoch[0] * sv_tol) for epoch in s]) @@ -476,8 +477,8 @@ def compute_con(self, indices, ranks, n_epochs=1): indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] + seed_idcs = seed_idcs.compressed() + target_idcs = target_idcs.compressed() con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] @@ -950,8 +951,8 @@ def compute_con(self, indices, ranks, n_epochs=1): indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] + seed_idcs = seed_idcs.compressed() + target_idcs = target_idcs.compressed() con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] @@ -1907,7 +1908,8 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # make sure padded indices are stored in the connectivity object if multivariate_con and indices is not None: - indices = tuple(np.array(indices_use)) # create a copy + # create a copy + indices = (indices_use[0].copy(), indices_use[1].copy()) # get the window function, wavelets, etc for different modes (spectral_params, mt_adaptive, n_times_spectrum, @@ -1919,20 +1921,12 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # unique signals for which we actually need to compute PSD etc. if multivariate_con: - sig_idx = np.unique(np.concatenate(np.concatenate( - indices_use))) - sig_idx = sig_idx[sig_idx != -1] + sig_idx = np.unique(indices_use.compressed()) remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} - remapping[-1] = -1 - remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) - con_i = 0 - for seed, target in zip(indices_use[0], indices_use[1]): - remapped_inds[0][con_i] = np.array([ - remapping[idx] for idx in seed]) - remapped_inds[1][con_i] = np.array([ - remapping[idx] for idx in target]) - con_i += 1 - remapped_sig = [remapping[idx] for idx in sig_idx] + remapped_inds = indices_use.copy() + for idx in sig_idx: + remapped_inds[indices_use == idx] = remapping[idx] + remapped_sig = np.unique(remapped_inds.compressed()) else: sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) n_signals_use = len(sig_idx) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 592291f0..a514382b 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1315,12 +1315,9 @@ def test_multivar_save_load(tmp_path): assert a == b -@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", - "mic", "mim"]) +@pytest.mark.parametrize("method", ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize("indices", [None, - (np.array([0, 1]), np.array([2, 3])), - (np.array([[0, 1]]), np.array([[2, 3]])) - ]) + (np.array([0, 1]), np.array([2, 3]))]) def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. @@ -1337,14 +1334,6 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") - # mutlivariate and bivariate methods require the right indices shape - if method in ["mic", "mim"]: - if indices is not None and indices[0].ndim == 1: - pytest.skip() - else: - if indices is not None and indices[0].ndim == 2: - pytest.skip() - # test the pair of method and indices defined to check the output indices con_epochs = spectral_connectivity_epochs( epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 @@ -1366,3 +1355,57 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): assert np.all(np.array(con.indices) == np.array(read_con.indices)) else: assert con.indices is None and read_con.indices is None + + +@pytest.mark.parametrize("method", ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize("indices", [None, + (np.array([[0, 1]]), np.array([[2, 3]]))]) +def test_multivar_spectral_connectivity_indices_roundtrip_io( + tmp_path, method, indices +): + """Test that indices values and type is maintained after saving. + + If `indices` is None, `indices` in the returned connectivity object should + be None, otherwise, `indices` should be a tuple. The type of `indices` and + its values should be retained after saving and reloading. + """ + rng = np.random.RandomState(0) + n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 + data = rng.randn(n_epochs, n_chs, n_times) + info = create_info(n_chs, sfreq, "eeg") + tmin = -1 + epochs = EpochsArray(data, info, tmin=tmin) + freqs = np.arange(10, 31) + tmp_file = os.path.join(tmp_path, "foo_mvc.nc") + + # test the pair of method and indices defined to check the output indices + if indices is None and method in ['gc', 'gc_tr']: + # indicesmust be specified for GC + pytest.skip() + + con_epochs = spectral_connectivity_epochs( + epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30, + gc_n_lags=10 + ) + con_time = spectral_connectivity_time( + epochs, freqs, method=method, indices=indices, sfreq=sfreq, + gc_n_lags=10 + ) + + for con in [con_epochs, con_time]: + con.save(tmp_file) + read_con = read_connectivity(tmp_file) + + if indices is not None: + # check indices of same type (tuples) + assert isinstance(con.indices, tuple) and isinstance( + read_con.indices, tuple + ) + # check indices are masked + assert all([np.ma.isMA(inds) for inds in con.indices] and + [np.ma.isMA(inds) for inds in read_con.indices]) + # check indices have same values + assert np.all([con_inds == read_inds for con_inds, read_inds in + zip(con.indices, read_con.indices)]) + else: + assert con.indices is None and read_con.indices is None diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 3798f699..4ac3f161 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -16,7 +16,7 @@ from .epochs import (_MICEst, _MIMEst, _GCEst, _GCTREst, _compute_freq_mask, _check_rank_input) from .smooth import _create_kernel, _smooth_spectra -from ..utils import check_indices, check_multivariate_indices, fill_doc +from ..utils import check_indices, _check_multivariate_indices, fill_doc _multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr'] @@ -407,46 +407,51 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') logger.info('using all indices for multivariate connectivity') - indices_use = (np.array([np.arange(n_signals, dtype=np.int32)]), - np.array([np.arange(n_signals, dtype=np.int32)])) + indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) + indices_use = np.ma.masked_array(indices_use, + mask=False, fill_value=-1) else: logger.info('only using indices for lower-triangular matrix') indices_use = np.tril_indices(n_signals, k=-1) else: if multivariate_con: - indices_use = check_multivariate_indices(indices) # pad with -1 + # mask indices + indices_use = _check_multivariate_indices(indices, n_signals) + indices_use = np.ma.concatenate([inds[np.newaxis] for inds in + indices_use]) + np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices[0], indices[1]): - intersection = np.intersect1d(seed, target) - if np.any(intersection != -1): # ignore padded entries + for seed, target in zip(indices_use[0], indices_use[1]): + intersection = np.intersect1d(seed.compressed(), + target.compressed()) + if intersection.size > 0: raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') # make sure padded indices are stored in the connectivity object - indices = tuple(np.array(indices_use)) # create a copy + # create a copy + indices = (indices_use[0].copy(), indices_use[1].copy()) else: indices_use = check_indices(indices) - # create copies of indices_use for independent manipulation - source_idx = np.array(indices_use[0]) - target_idx = np.array(indices_use[1]) - n_cons = len(source_idx) + n_cons = len(indices_use[0]) # unique signals for which we actually need to compute the CSD of if multivariate_con: - signals_use = np.unique(np.concatenate(np.concatenate(indices_use))) - signals_use = signals_use[signals_use != -1] + signals_use = np.unique(indices_use.compressed()) remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(signals_use)} - remapping[-1] = -1 + remapped_inds = indices_use.copy() # multivariate functions expect seed/target remapping - con_i = 0 - for seed, target in zip(indices_use[0], indices_use[1]): - source_idx[con_i] = np.array([remapping[idx] for idx in seed]) - target_idx[con_i] = np.array([remapping[idx] for idx in target]) - con_i += 1 + for idx in signals_use: + remapped_inds[indices_use == idx] = remapping[idx] + source_idx = remapped_inds[0] + target_idx = remapped_inds[1] max_n_channels = len(indices_use[0][0]) else: # no indices remapping required for bivariate functions signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + source_idx = indices_use[0].copy() + target_idx = indices_use[1].copy() max_n_channels = len(indices_use[0]) # check rank input and compute data ranks if necessary diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index 1e5822eb..caead679 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -4,7 +4,7 @@ from mne_connectivity import Connectivity from mne_connectivity.utils import (degree, check_indices, - check_multivariate_indices, + _check_multivariate_indices, seed_target_indices, seed_target_multivariate_indices) @@ -34,14 +34,22 @@ def test_seed_target_indices(): seeds = [[0, 1]] targets = [[2, 3], [3, 4]] indices = seed_target_multivariate_indices(seeds, targets) - assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), - np.array([[2, 3], [3, 4]]))) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3], [3, 4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) # ragged indices seeds = [[0, 1]] targets = [[2, 3, 4], [4]] indices = seed_target_multivariate_indices(seeds, targets) - assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), - np.array([[2, 3, 4], [4, -1, -1]]))) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3, 4], [4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) # test error catching # non-array-like seeds/targets with pytest.raises(TypeError, @@ -58,7 +66,7 @@ def test_seed_target_indices(): def test_check_indices(): - """Test indices checking functions.""" + """Test check_indices function.""" # bivariate indices # test error catching with pytest.raises(ValueError, @@ -77,39 +85,68 @@ def test_check_indices(): nested_indices = ([[0]], [[1]]) check_indices(nested_indices) - # multivariate indices + +def test_check_multivariate_indices(): + """Test _check_multivariate_indices function.""" + n_signals = 5 + mask_value = -1 # non-ragged indices seeds = [[0, 1], [0, 1]] targets = [[2, 3], [3, 4]] - indices = check_multivariate_indices((seeds, targets)) + indices = _check_multivariate_indices((seeds, targets), n_signals) + assert all(np.ma.isMA(inds) for inds in indices) + assert all(inds.fill_value == mask_value for inds in indices) + assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), + np.array([[2, 3], [3, 4]]))) + # non-ragged indices with negative values + seeds = [[0, 1], [0, 1]] + targets = [[2, 3], [3, -1]] + indices = _check_multivariate_indices((seeds, targets), n_signals) + assert all(np.ma.isMA(inds) for inds in indices) + assert all(inds.fill_value == mask_value for inds in indices) assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), np.array([[2, 3], [3, 4]]))) # ragged indices seeds = [[0, 1], [0, 1]] targets = [[2, 3, 4], [4]] - indices = check_multivariate_indices((seeds, targets)) + indices = _check_multivariate_indices((seeds, targets), n_signals) + assert all(np.ma.isMA(inds) for inds in indices) + assert all(inds.fill_value == mask_value for inds in indices) assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), np.array([[2, 3, 4], [4, -1, -1]]))) + # ragged indices with negative values + seeds = [[0, 1], [0, 1]] + targets = [[2, 3, 4], [-1]] + indices = _check_multivariate_indices((seeds, targets), n_signals) + assert all(np.ma.isMA(inds) for inds in indices) + assert all(inds.fill_value == mask_value for inds in indices) + assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), + np.array([[2, 3, 4], [4, -1, -1]]))) + # test error catching with pytest.raises(ValueError, match='indices must be a tuple of length 2'): non_tuple_indices = [np.array([0, 1]), np.array([2, 3])] - check_multivariate_indices(non_tuple_indices) + _check_multivariate_indices(non_tuple_indices, n_signals) with pytest.raises(ValueError, match='indices must be a tuple of length 2'): non_len2_indices = (np.array([0]), np.array([1]), np.array([2])) - check_multivariate_indices(non_len2_indices) + _check_multivariate_indices(non_len2_indices, n_signals) with pytest.raises(ValueError, match='index arrays indices'): non_equal_len_indices = (np.array([[0]]), np.array([[1], [2]])) - check_multivariate_indices(non_equal_len_indices) + _check_multivariate_indices(non_equal_len_indices, n_signals) with pytest.raises(TypeError, match='multivariate indices must contain array-likes'): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) - check_multivariate_indices(non_nested_indices) + _check_multivariate_indices(non_nested_indices, n_signals) with pytest.raises(ValueError, match='multivariate indices cannot contain repeated'): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) - check_multivariate_indices(repeated_indices) + _check_multivariate_indices(repeated_indices, n_signals) + with pytest.raises(ValueError, + match='a negative channel index is not present in the'): + missing_chan_indices = (np.array([[0, 1]]), np.array([[2, -5]])) + _check_multivariate_indices(missing_chan_indices, n_signals) def test_degree(): diff --git a/mne_connectivity/utils/__init__.py b/mne_connectivity/utils/__init__.py index 0df454a4..171711b8 100644 --- a/mne_connectivity/utils/__init__.py +++ b/mne_connectivity/utils/__init__.py @@ -1,4 +1,4 @@ from .docs import fill_doc -from .utils import (check_indices, check_multivariate_indices, degree, +from .utils import (check_indices, _check_multivariate_indices, degree, seed_target_indices, seed_target_multivariate_indices, parallel_loop, _prepare_xarray_mne_data_structures) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 16f56f09..2cab0007 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -83,18 +83,22 @@ def check_indices(indices): return indices -def check_multivariate_indices(indices): - """Check indices parameter for multivariate connectivity and pad it. +def _check_multivariate_indices(indices, n_chans): + """Check indices parameter for multivariate connectivity and mask it. Parameters ---------- indices : tuple of array of array of int, shape (2, n_cons, variable) Tuple containing index sets. + n_chans : int + The number of channels in the data. Used when converting negative + indices to positive indices. + Returns ------- indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans) - The indices padded with the invalid channel index ``-1``. + The indices as a masked array. Notes ----- @@ -108,26 +112,35 @@ def check_multivariate_indices(indices): connection must be unique. If the seed and target indices are given as lists or tuples, they will be - converted to numpy arrays. In case the number of channels differs across + converted to numpy arrays. Because the number of channels can differ across connections or between the seeds and targets for a given connection (i.e. - ragged indices), the returned array will be padded with the invalid channel - index ``-1`` according to the maximum number of channels in the seed or - target of any one connection. E.g. the ragged indices of shape ``(2, - n_cons, variable)``:: + ragged/jagged indices), the returned array will be padded out to a 'full' + array with an invalid index (``-1``) according to the maximum number of + channels in the seed or target of any one connection. These invalid + entries are then masked and returned as numpy masked arrays. E.g. the + ragged indices of shape ``(2, n_cons, variable)``:: indices = ([[0, 1], [0, 1 ]], # seeds [[2, 3], [4, 5, 6]]) # targets - would be returned as:: + would be padded to full arrays:: + + indices = ([[0, 1, -1], [0, 1, -1]], # seeds + [[2, 3, -1], [4, 5, 6]]) # targets + + to have shape ``(2, n_cons, max_n_chans)``, where ``max_n_chans = 3``. The + invalid entries are then masked:: - indices = (np.array([[0, 1, -1], [0, 1, -1]]), # seeds - np.array([[2, 3, -1], [4, 5, -1]])) # targets + indices = ([[0, 1, --], [0, 1, --]], # seeds + [[2, 3, --], [4, 5, 6]]) # targets - where the indices have been padded with ``-1`` to have shape ``(2, n_cons, - max_n_chans)``, where ``max_n_chans = 3``. More information on working with - multivariate indices and handling connections where the number of seeds and - targets are not equal can be found in the - :doc:`../auto_examples/handling_ragged_arrays` example. + In case "indices" contains negative values to index channels, these will be + converted to the corresponding positive-valued index before any masking is + applied. + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. """ if not isinstance(indices, tuple) or len(indices) != 2: raise ValueError('indices must be a tuple of length 2') @@ -137,29 +150,45 @@ def check_multivariate_indices(indices): 'have the same length') n_cons = len(indices[0]) + invalid = -1 max_n_chans = 0 - for inds in ([*indices[0], *indices[1]]): - if not isinstance(inds, (np.ndarray, list, tuple)): - raise TypeError( - 'multivariate indices must contain array-likes of channel ' - 'indices for each seed and target') - if len(inds) != len(np.unique(inds)): - raise ValueError( - 'multivariate indices cannot contain repeated channels within ' - 'a seed or target') - max_n_chans = max(max_n_chans, len(inds)) + for group_idx, group in enumerate(indices): + for con_idx, con in enumerate(group): + if not isinstance(con, (np.ndarray, list, tuple)): + raise TypeError( + 'multivariate indices must contain array-likes of channel ' + 'indices for each seed and target') + con = np.array(con) + if len(con) != len(np.unique(con)): + raise ValueError( + 'multivariate indices cannot contain repeated channels ' + 'within a seed or target') + max_n_chans = max(max_n_chans, len(con)) + # convert negative to positive indices + for chan_idx, chan in enumerate(con): + if chan < 0: + if chan * -1 >= n_chans: + raise ValueError( + 'a negative channel index is not present in the ' + 'data' + ) + indices[group_idx][con_idx][chan_idx] = chan % n_chans # pad indices to avoid ragged arrays - padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32), - np.full((n_cons, max_n_chans), -1, dtype=np.int32)) + padded_indices = (np.full((n_cons, max_n_chans), invalid, dtype=np.int32), + np.full((n_cons, max_n_chans), invalid, dtype=np.int32)) con_i = 0 for seed, target in zip(indices[0], indices[1]): padded_indices[0][con_i, :len(seed)] = seed padded_indices[1][con_i, :len(target)] = target con_i += 1 - return padded_indices + # mask invalid indices + masked_indices = (np.ma.masked_values(padded_indices[0], invalid), + np.ma.masked_values(padded_indices[1], invalid)) + + return masked_indices def seed_target_indices(seeds, targets): @@ -221,8 +250,8 @@ def seed_target_multivariate_indices(seeds, targets): Returns ------- - indices : tuple of array of array of int, shape (2, n_cons, max_n_chans) - The indices padded with the invalid channel index ``-1``. + indices : tuple of array of array of int, shape (2, n_cons, variable) + The indices as a numpy object array. Notes ----- @@ -232,12 +261,8 @@ def seed_target_multivariate_indices(seeds, targets): channels in the data. The length of indices for each connection do not need to be equal. Furthermore, all indices within a connection must be unique. - ``seeds`` and ``targets`` will be expanded such that connectivity will be - computed between each set of seeds and targets. In case the number of - channels differs across connections or between the seeds and targets for a - given connection (i.e. ragged indices), the returned array will be padded - with the invalid channel index ``-1`` according to the maximum number of - channels in the seed or target of any one connection. E.g. ``seeds`` and + Because the number of channels per connection can vary, the indices are + stored as numpy arrays with ``dtype=object``. E.g. ``seeds`` and ``targets``:: seeds = [[0]] @@ -245,8 +270,8 @@ def seed_target_multivariate_indices(seeds, targets): would be returned as:: - indices = (np.array([[0 ], [0 ]]), # seeds - np.array([[1, 2], [3, 4, 5]])) # targets + indices = (np.array([[0 ], [0 ]], dtype=object), # seeds + np.array([[1, 2], [3, 4, 5]], dtype=object)) # targets Even if the number of channels does not vary, the indices will still be stored as object arrays for compatibility. @@ -263,7 +288,6 @@ def seed_target_multivariate_indices(seeds, targets): ): raise TypeError('`seeds` and `targets` must be array-like') - n_chans = [] for inds in [*seeds, *targets]: if not isinstance(inds, array_like): raise TypeError( @@ -271,27 +295,15 @@ def seed_target_multivariate_indices(seeds, targets): if len(inds) != len(np.unique(inds)): raise ValueError( '`seeds` and `targets` cannot contain repeated channels') - n_chans.append(len(inds)) - max_n_chans = max(n_chans) - n_cons = len(seeds) * len(targets) - # pad indices to avoid ragged arrays - padded_seeds = np.full((len(seeds), max_n_chans), -1, dtype=np.int32) - padded_targets = np.full((len(targets), max_n_chans), -1, dtype=np.int32) - for con_i, seed in enumerate(seeds): - padded_seeds[con_i, :len(seed)] = seed - for con_i, target in enumerate(targets): - padded_targets[con_i, :len(target)] = target - - # create final indices - indices = (np.zeros((n_cons, max_n_chans), dtype=np.int32), - np.zeros((n_cons, max_n_chans), dtype=np.int32)) - con_i = 0 - for seed in padded_seeds: - for target in padded_targets: - indices[0][con_i] = seed - indices[1][con_i] = target - con_i += 1 + indices = [[], []] + for seed in seeds: + for target in targets: + indices[0].append(np.array(seed)) + indices[1].append(np.array(target)) + + indices = (np.array(indices[0], dtype=object), + np.array(indices[1], dtype=object)) return indices From 9cfad80a5777f05450b35d6f65621af058c0a506 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 15:49:13 +0100 Subject: [PATCH 31/40] Revert "switch to masked arrays for indices" This reverts commit 921cf9db82fc617c0702fac3f56529ef54588320. --- examples/handling_ragged_arrays.py | 39 +++--- mne_connectivity/io.py | 6 +- .../spectral/epochs_multivariate.py | 53 ++++---- .../spectral/tests/test_spectral.py | 6 +- mne_connectivity/spectral/time.py | 42 +++--- mne_connectivity/tests/test_utils.py | 16 +-- mne_connectivity/utils/utils.py | 125 ++++++++---------- 7 files changed, 133 insertions(+), 154 deletions(-) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index ecb4a9c4..2bc0dcc0 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -62,8 +62,8 @@ # ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'), # np.array([[2, 3, 4], [4 ]], dtype='object')) # -# **N.B. Note that since NumPy v1.19.0, dtype='object' must be specified when -# forming ragged arrays.** +# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be +# specified.** # # Just as for bivariate connectivity, the length of ``indices[0]`` and # ``indices[1]`` is equal (i.e. the number of connections), however information @@ -71,16 +71,19 @@ # array. Importantly, these indices are ragged, as the first connection will be # computed between 2 seed and 3 target channels, and the second connection # between 4 seed and 1 target channel. The connectivity functions will -# recognise the indices as being ragged, and pad them to a 'full' array by -# adding placeholder values which are masked accordingly. This makes the -# indices easier to work with, and also compatible with the engine used to save -# connectivity objects. For example, the above indices would become:: +# recognise the indices as being ragged, and pad them accordingly to make them +# easier to work with and compatible with the h5netcdf saving engine. The known +# value used to pad the arrays is ``-1``, an invalid channel index. The above +# indices would be padded to:: # -# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]), -# np.array([[2, 3, 4, --], [4, --, --, --]])) +# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]), +# np.array([[2, 3, 4, -1], [4, -1, -1, -1]])) # -# where ``--`` are masked entries. These indices are what is stored in the -# returned connectivity objects. +# These indices are what is stored in the connectivity object, and is also the +# format of indices returned from the helper functions +# :func:`~mne_connectivity.check_multivariate_indices` and +# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also +# possible to pass the padded indices to the connectivity functions directly. # # For the connectivity results themselves, the methods available in # MNE-Connectivity combine information across the different channels into a @@ -115,11 +118,11 @@ max_n_chans = max( [len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])]) -# show that the padded indices entries are masked -assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels -assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels -assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels -assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels +# show that the padded indices entries are all -1 +assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels +assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels +assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels +assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels # patterns have shape [seeds/targets x cons x max channels x freqs (x times)] assert patterns.shape == (2, n_cons, max_n_chans, n_freqs) @@ -134,11 +137,11 @@ seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] -# extract patterns for second connection using the padded, masked indices +# extract patterns for second connection using the padded indices (pad = -1) seed_patterns_con2 = ( - patterns[0, 1, :padded_indices[0][1].count()]) + patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)]) target_patterns_con2 = ( - patterns[1, 1, :padded_indices[1][1].count()]) + patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)]) # show that shapes of patterns are correct assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) diff --git a/mne_connectivity/io.py b/mne_connectivity/io.py index e8d9b916..63aa3501 100644 --- a/mne_connectivity/io.py +++ b/mne_connectivity/io.py @@ -53,11 +53,9 @@ def _xarray_to_conn(array, cls_func): event_id = dict(zip(event_id_keys, event_id_vals)) array.attrs['event_id'] = event_id - # convert indices numpy arrays to a tuple of masked arrays - # (only multivariate connectivity indices saved as arrays) + # convert indices numpy arrays to a tuple of arrays if isinstance(array.attrs['indices'], np.ndarray): - array.attrs['indices'] = tuple( - np.ma.masked_values(array.attrs['indices'], -1)) + array.attrs['indices'] = tuple(array.attrs['indices']) # create the connectivity class conn = cls_func( diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index db648456..28077adb 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -31,18 +31,12 @@ def _check_indices(indices, method, n_signals): logger.info('using all indices for multivariate connectivity') indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], np.arange(n_signals, dtype=int)[np.newaxis, :]) - indices_use = np.ma.masked_array(indices_use, - mask=False, fill_value=-1) else: - indices_use = check_multivariate_indices(indices) # mask indices - indices_use = np.ma.concatenate([inds[np.newaxis] for inds in - indices_use]) - np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. + indices_use = check_multivariate_indices(indices) # pad with -1 if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices_use[0], indices_use[1]): - intersection = np.intersect1d(seed.compressed(), - target.compressed()) - if intersection.size > 0: + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') @@ -65,10 +59,16 @@ def _check_rank_input(rank, data, indices): else: data_arr = data + # XXX: Unpadding of arrays after already padding them is perhaps not so + # efficient. However, we need to remove the padded values to + # ensure only the correct channels are indexed, and having two + # versions of indices is a bit messy currently. A candidate for + # refactoring to simplify code. + for group_i in range(2): # seeds and targets for con_i, con_idcs in enumerate(indices[group_i]): - s = np.linalg.svd(data_arr[:, con_idcs.compressed()], - compute_uv=False) + con_idcs = con_idcs[con_idcs != -1] # -1 is padded value + s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) rank[group_i][con_i] = np.min( [np.count_nonzero(epoch >= epoch[0] * sv_tol) for epoch in s]) @@ -197,8 +197,8 @@ def compute_con(self, indices, ranks, n_epochs=1): indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - seed_idcs = seed_idcs.compressed() - target_idcs = target_idcs.compressed() + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] @@ -432,8 +432,8 @@ def compute_con(self, indices, ranks, n_epochs=1): indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - seed_idcs = seed_idcs.compressed() - target_idcs = target_idcs.compressed() + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] @@ -1009,8 +1009,7 @@ def spectral_connectivity_epochs_multivariate( # make sure padded indices are stored in the connectivity object if indices is not None: - # create a copy - indices = (indices_use[0].copy(), indices_use[1].copy()) + indices = tuple(np.array(indices_use)) # create a copy # get the window function, wavelets, etc for different modes (spectral_params, mt_adaptive, n_times_spectrum, @@ -1021,12 +1020,20 @@ def spectral_connectivity_epochs_multivariate( cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) # unique signals for which we actually need to compute CSD - sig_idx = np.unique(indices_use.compressed()) + sig_idx = np.unique(np.concatenate(np.concatenate( + indices_use))) + sig_idx = sig_idx[sig_idx != -1] remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} - remapped_inds = indices_use.copy() - for idx in sig_idx: - remapped_inds[indices_use == idx] = remapping[idx] - remapped_sig = np.unique(remapped_inds.compressed()) + remapping[-1] = -1 + remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + remapped_inds[0][con_i] = np.array([ + remapping[idx] for idx in seed]) + remapped_inds[1][con_i] = np.array([ + remapping[idx] for idx in target]) + con_i += 1 + remapped_sig = [remapping[idx] for idx in sig_idx] n_signals_use = len(sig_idx) # map indices to unique indices diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index d05ec61e..54bfafa5 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1392,11 +1392,7 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io( assert isinstance(con.indices, tuple) and isinstance( read_con.indices, tuple ) - # check indices are masked - assert all([np.ma.isMA(inds) for inds in con.indices] and - [np.ma.isMA(inds) for inds in read_con.indices]) # check indices have same values - assert np.all([con_inds == read_inds for con_inds, read_inds in - zip(con.indices, read_con.indices)]) + assert np.all(np.array(con.indices) == np.array(read_con.indices)) else: assert con.indices is None and read_con.indices is None diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index d44f5be3..d0059ace 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -408,50 +408,46 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], - np.arange(n_signals, dtype=int)[np.newaxis, :]) - indices_use = np.ma.masked_array(indices_use, - mask=False, fill_value=-1) + indices_use = (np.array([np.arange(n_signals, dtype=np.int32)]), + np.array([np.arange(n_signals, dtype=np.int32)])) else: logger.info('only using indices for lower-triangular matrix') indices_use = np.tril_indices(n_signals, k=-1) else: if multivariate_con: - indices_use = check_multivariate_indices(indices) # mask indices - indices_use = np.ma.concatenate([inds[np.newaxis] for inds in - indices_use]) - np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. + indices_use = check_multivariate_indices(indices) # pad with -1 if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices_use[0], indices_use[1]): - intersection = np.intersect1d(seed.compressed(), - target.compressed()) - if intersection.size > 0: + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') # make sure padded indices are stored in the connectivity object - # create a copy - indices = (indices_use[0].copy(), indices_use[1].copy()) + indices = tuple(np.array(indices_use)) # create a copy else: indices_use = check_indices(indices) - n_cons = len(indices_use[0]) + # create copies of indices_use for independent manipulation + source_idx = np.array(indices_use[0]) + target_idx = np.array(indices_use[1]) + n_cons = len(source_idx) # unique signals for which we actually need to compute the CSD of if multivariate_con: - signals_use = np.unique(indices_use.compressed()) + signals_use = np.unique(np.concatenate(np.concatenate(indices_use))) + signals_use = signals_use[signals_use != -1] remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(signals_use)} - remapped_inds = indices_use.copy() + remapping[-1] = -1 # multivariate functions expect seed/target remapping - for idx in signals_use: - remapped_inds[indices_use == idx] = remapping[idx] - source_idx = remapped_inds[0] - target_idx = remapped_inds[1] + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + source_idx[con_i] = np.array([remapping[idx] for idx in seed]) + target_idx[con_i] = np.array([remapping[idx] for idx in target]) + con_i += 1 max_n_channels = len(indices_use[0][0]) else: # no indices remapping required for bivariate functions signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) - source_idx = indices_use[0].copy() - target_idx = indices_use[1].copy() max_n_channels = len(indices_use[0]) # check rank input and compute data ranks if necessary diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index 31ed3c79..1e5822eb 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -34,22 +34,14 @@ def test_seed_target_indices(): seeds = [[0, 1]] targets = [[2, 3], [3, 4]] indices = seed_target_multivariate_indices(seeds, targets) - match_indices = (np.array([[0, 1], [0, 1]], dtype=object), - np.array([[2, 3], [3, 4]], dtype=object)) - for type_i in range(2): - for con_i in range(len(indices[0])): - assert np.all(indices[type_i][con_i] == - match_indices[type_i][con_i]) + assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), + np.array([[2, 3], [3, 4]]))) # ragged indices seeds = [[0, 1]] targets = [[2, 3, 4], [4]] indices = seed_target_multivariate_indices(seeds, targets) - match_indices = (np.array([[0, 1], [0, 1]], dtype=object), - np.array([[2, 3, 4], [4]], dtype=object)) - for type_i in range(2): - for con_i in range(len(indices[0])): - assert np.all(indices[type_i][con_i] == - match_indices[type_i][con_i]) + assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), + np.array([[2, 3, 4], [4, -1, -1]]))) # test error catching # non-array-like seeds/targets with pytest.raises(TypeError, diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 46c3afc7..a0c69bd9 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -83,23 +83,18 @@ def check_indices(indices): return indices -def check_multivariate_indices(indices, n_chans=None): - """Check indices parameter for multivariate connectivity and mask it. +def check_multivariate_indices(indices): + """Check indices parameter for multivariate connectivity and pad it. Parameters ---------- indices : tuple of array of array of int, shape (2, n_cons, variable) Tuple containing index sets. - n_chans : int | None (default None) - The number of channels in the data. Used when converting negative - indices to positive indices. Cannot be ``None`` if negative indices are - present. - Returns ------- indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans) - The indices as a masked array. + The indices padded with the invalid channel index ``-1``. Notes ----- @@ -113,35 +108,26 @@ def check_multivariate_indices(indices, n_chans=None): connection must be unique. If the seed and target indices are given as lists or tuples, they will be - converted to numpy arrays. Because the number of channels can differ across + converted to numpy arrays. In case the number of channels differs across connections or between the seeds and targets for a given connection (i.e. - ragged/jagged indices), the returned array will be padded out to a 'full' - array with an invalid index (``-1``) according to the maximum number of - channels in the seed or target of any one connection. These invalid - entries are then masked and returned as numpy masked arrays. E.g. the - ragged indices of shape ``(2, n_cons, variable)``:: + ragged indices), the returned array will be padded with the invalid channel + index ``-1`` according to the maximum number of channels in the seed or + target of any one connection. E.g. the ragged indices of shape ``(2, + n_cons, variable)``:: indices = ([[0, 1], [0, 1 ]], # seeds [[2, 3], [4, 5, 6]]) # targets - would be padded to full arrays:: - - indices = ([[0, 1, -1], [0, 1, -1]], # seeds - [[2, 3, -1], [4, 5, 6]]) # targets - - to have shape ``(2, n_cons, max_n_chans)``, where ``max_n_chans = 3``. The - invalid entries are then masked:: - - indices = ([[0, 1, --], [0, 1, --]], # seeds - [[2, 3, --], [4, 5, 6]]) # targets + would be returned as:: - In case "indices" contains negative values to index channels, these will be - converted to the corresponding positive-valued index before any masking is - applied. + indices = (np.array([[0, 1, -1], [0, 1, -1]]), # seeds + np.array([[2, 3, -1], [4, 5, -1]])) # targets - More information on working with multivariate indices and handling - connections where the number of seeds and targets are not equal can be - found in the :doc:`../auto_examples/handling_ragged_arrays` example. + where the indices have been padded with ``-1`` to have shape ``(2, n_cons, + max_n_chans)``, where ``max_n_chans = 3``. More information on working with + multivariate indices and handling connections where the number of seeds and + targets are not equal can be found in the + :doc:`../auto_examples/handling_ragged_arrays` example. """ if not isinstance(indices, tuple) or len(indices) != 2: raise ValueError('indices must be a tuple of length 2') @@ -151,7 +137,6 @@ def check_multivariate_indices(indices, n_chans=None): 'have the same length') n_cons = len(indices[0]) - invalid = -1 max_n_chans = 0 for inds in ([*indices[0], *indices[1]]): @@ -164,32 +149,17 @@ def check_multivariate_indices(indices, n_chans=None): 'multivariate indices cannot contain repeated channels within ' 'a seed or target') max_n_chans = max(max_n_chans, len(inds)) - # convert negative to positive indices - if any(idx < 0 for idx in inds): - if n_chans is None: - raise ValueError( - '`n_chans` must be given if there are negative values ' - 'in `indices`') - if any(idx * -1 > n_chans for idx in inds[inds < 0]): - raise ValueError( - 'a channel index is not present in the data' - ) - inds[inds < 0] = inds[inds < 0] % n_chans # pad indices to avoid ragged arrays - padded_indices = (np.full((n_cons, max_n_chans), invalid, dtype=np.int32), - np.full((n_cons, max_n_chans), invalid, dtype=np.int32)) + padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32), + np.full((n_cons, max_n_chans), -1, dtype=np.int32)) con_i = 0 for seed, target in zip(indices[0], indices[1]): padded_indices[0][con_i, :len(seed)] = seed padded_indices[1][con_i, :len(target)] = target con_i += 1 - # mask invalid indices - masked_indices = (np.ma.masked_values(padded_indices[0], invalid), - np.ma.masked_values(padded_indices[1], invalid)) - - return masked_indices + return padded_indices def seed_target_indices(seeds, targets): @@ -251,8 +221,8 @@ def seed_target_multivariate_indices(seeds, targets): Returns ------- - indices : tuple of array of array of int, shape (2, n_cons, variable) - The indices as a numpy object array. + indices : tuple of array of array of int, shape (2, n_cons, max_n_chans) + The indices padded with the invalid channel index ``-1``. Notes ----- @@ -262,8 +232,12 @@ def seed_target_multivariate_indices(seeds, targets): channels in the data. The length of indices for each connection do not need to be equal. Furthermore, all indices within a connection must be unique. - Because the number of channels per connection can vary, the indices are - stored as numpy arrays with ``dtype=object``. E.g. ``seeds`` and + ``seeds`` and ``targets`` will be expanded such that connectivity will be + computed between each set of seeds and targets. In case the number of + channels differs across connections or between the seeds and targets for a + given connection (i.e. ragged indices), the returned array will be padded + with the invalid channel index ``-1`` according to the maximum number of + channels in the seed or target of any one connection. E.g. ``seeds`` and ``targets``:: seeds = [[0]] @@ -271,15 +245,15 @@ def seed_target_multivariate_indices(seeds, targets): would be returned as:: - indices = (np.array([[0 ], [0 ]], dtype=object), # seeds - np.array([[1, 2], [3, 4, 5]], dtype=object)) # targets - - Even if the number of channels does not vary, the indices will still be - stored as object arrays for compatibility. + indices = (np.array([[0, -1, -1], [0, -1, -1]]), # seeds + np.array([[1, 2, -1], [3, 4, 5]])) # targets - More information on working with multivariate indices and handling - connections where the number of seeds and targets are not equal can be - found in the :doc:`../auto_examples/handling_ragged_arrays` example. + where the indices have been padded with ``-1`` to have shape ``(2, n_cons, + max_n_chans)``, where ``n_cons = n_unique_seeds * n_unique_targets`` and + ``max_n_chans = 3``. More information on working with multivariate indices + and handling connections where the number of seeds and targets are not + equal can be found in the :doc:`../auto_examples/handling_ragged_arrays` + example. """ array_like = (np.ndarray, list, tuple) @@ -289,6 +263,7 @@ def seed_target_multivariate_indices(seeds, targets): ): raise TypeError('`seeds` and `targets` must be array-like') + n_chans = [] for inds in [*seeds, *targets]: if not isinstance(inds, array_like): raise TypeError( @@ -296,15 +271,27 @@ def seed_target_multivariate_indices(seeds, targets): if len(inds) != len(np.unique(inds)): raise ValueError( '`seeds` and `targets` cannot contain repeated channels') + n_chans.append(len(inds)) + max_n_chans = max(n_chans) + n_cons = len(seeds) * len(targets) - indices = [[], []] - for seed in seeds: - for target in targets: - indices[0].append(np.array(seed)) - indices[1].append(np.array(target)) - - indices = (np.array(indices[0], dtype=object), - np.array(indices[1], dtype=object)) + # pad indices to avoid ragged arrays + padded_seeds = np.full((len(seeds), max_n_chans), -1, dtype=np.int32) + padded_targets = np.full((len(targets), max_n_chans), -1, dtype=np.int32) + for con_i, seed in enumerate(seeds): + padded_seeds[con_i, :len(seed)] = seed + for con_i, target in enumerate(targets): + padded_targets[con_i, :len(target)] = target + + # create final indices + indices = (np.zeros((n_cons, max_n_chans), dtype=np.int32), + np.zeros((n_cons, max_n_chans), dtype=np.int32)) + con_i = 0 + for seed in padded_seeds: + for target in padded_targets: + indices[0][con_i] = seed + indices[1][con_i] = target + con_i += 1 return indices From 55c19db20f69a97b986961fcd44d465408de5cb5 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 15:49:20 +0100 Subject: [PATCH 32/40] Revert "bug fix missing refactoring for example" This reverts commit 6fe682fc7d2ea88266641fd5b823088d122acc6c. --- examples/handling_ragged_arrays.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index 2bc0dcc0..67f06a4e 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -14,7 +14,7 @@ import numpy as np -from mne_connectivity import spectral_connectivity_epochs_multivariate +from mne_connectivity import spectral_connectivity_epochs ############################################################################### # Background @@ -44,7 +44,7 @@ # targets = [[2, 3, 4], [4 ]] # # The ``indices`` parameter passed to -# :func:`~mne_connectivity.spectral_connectivity_epochs_multivariate` and +# :func:`~mne_connectivity.spectral_connectivity_epochs` and # :func:`~mne_connectivity.spectral_connectivity_time` must be a tuple of # array-likes, meaning # that the indices can be passed as a tuple of: lists; tuples; or NumPy arrays. @@ -108,7 +108,7 @@ [[2, 3, 4], [4]]) # targets # compute connectivity -con = spectral_connectivity_epochs_multivariate( +con = spectral_connectivity_epochs( data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30, verbose=False) patterns = np.array(con.attrs['patterns']) From 4865e7c336b4760feb1fef06ffb86ca404ad44dc Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 15:49:25 +0100 Subject: [PATCH 33/40] Revert "Squashed commit of the following:" This reverts commit 5bad5fb8913e66f2a46548523725afce3ff385a9. --- doc/api.rst | 1 - examples/granger_causality.py | 14 +- examples/mic_mim.py | 10 +- mne_connectivity/__init__.py | 4 +- mne_connectivity/spectral/__init__.py | 3 +- mne_connectivity/spectral/epochs.py | 1693 ++++++++++++++++- mne_connectivity/spectral/epochs_bivariate.py | 729 ------- .../spectral/epochs_multivariate.py | 1129 ----------- .../spectral/tests/test_spectral.py | 158 +- mne_connectivity/spectral/time.py | 5 +- 10 files changed, 1693 insertions(+), 2053 deletions(-) delete mode 100644 mne_connectivity/spectral/epochs_bivariate.py delete mode 100644 mne_connectivity/spectral/epochs_multivariate.py diff --git a/doc/api.rst b/doc/api.rst index 3fe85832..c91f9c02 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -47,7 +47,6 @@ on numpy array inputs. phase_slope_index vector_auto_regression spectral_connectivity_epochs - spectral_connectivity_epochs_multivariate spectral_connectivity_time Reading functions diff --git a/examples/granger_causality.py b/examples/granger_causality.py index 4129dadc..64a657db 100644 --- a/examples/granger_causality.py +++ b/examples/granger_causality.py @@ -20,7 +20,7 @@ import mne from mne.datasets.fieldtrip_cmc import data_path -from mne_connectivity import spectral_connectivity_epochs_multivariate +from mne_connectivity import spectral_connectivity_epochs ############################################################################### # Background @@ -161,10 +161,10 @@ indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A # compute Granger causality -gc_ab = spectral_connectivity_epochs_multivariate( +gc_ab = spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # A => B -gc_ba = spectral_connectivity_epochs_multivariate( +gc_ba = spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ba, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # B => A freqs = gc_ab.freqs @@ -262,10 +262,10 @@ # %% # compute GC on time-reversed signals -gc_tr_ab = spectral_connectivity_epochs_multivariate( +gc_tr_ab = spectral_connectivity_epochs( epochs, method=['gc_tr'], indices=indices_ab, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # TR[A => B] -gc_tr_ba = spectral_connectivity_epochs_multivariate( +gc_tr_ba = spectral_connectivity_epochs( epochs, method=['gc_tr'], indices=indices_ba, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=20) # TR[B => A] @@ -317,7 +317,7 @@ # %% -gc_ab_60 = spectral_connectivity_epochs_multivariate( +gc_ab_60 = spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=(np.array([5]), np.array([5])), gc_n_lags=60) # A => B @@ -375,7 +375,7 @@ # %% try: - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None, gc_n_lags=20, verbose=False) # A => B print('Success!') diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 62674e75..87111586 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -25,9 +25,7 @@ import mne from mne import EvokedArray, make_fixed_length_epochs from mne.datasets.fieldtrip_cmc import data_path -from mne_connectivity import (seed_target_indices, - spectral_connectivity_epochs, - spectral_connectivity_epochs_multivariate) +from mne_connectivity import seed_target_indices, spectral_connectivity_epochs ############################################################################### # Background @@ -89,7 +87,7 @@ target_names = [epochs.info['ch_names'][idx] for idx in targets] # multivariate imaginary part of coherency -(mic, mim) = spectral_connectivity_epochs_multivariate( +(mic, mim) = spectral_connectivity_epochs( epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, rank=None) @@ -292,7 +290,7 @@ # %% indices = (np.array([[*seeds, *targets]]), np.array([[*seeds, *targets]])) -gim = spectral_connectivity_epochs_multivariate( +gim = spectral_connectivity_epochs( epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None, verbose=False) @@ -344,7 +342,7 @@ # %% -(mic_red, mim_red) = spectral_connectivity_epochs_multivariate( +(mic_red, mim_red) = spectral_connectivity_epochs( epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, rank=([25], [25])) diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index 32488b33..c2f03a6c 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -15,9 +15,7 @@ from .effective import phase_slope_index from .envelope import envelope_correlation, symmetric_orth from .io import read_connectivity -from .spectral import (spectral_connectivity_time, - spectral_connectivity_epochs, - spectral_connectivity_epochs_multivariate) +from .spectral import spectral_connectivity_time, spectral_connectivity_epochs from .vector_ar import vector_auto_regression, select_order from .utils import (check_indices, check_multivariate_indices, degree, seed_target_indices, seed_target_multivariate_indices) diff --git a/mne_connectivity/spectral/__init__.py b/mne_connectivity/spectral/__init__.py index f2252db9..a0f06ef6 100644 --- a/mne_connectivity/spectral/__init__.py +++ b/mne_connectivity/spectral/__init__.py @@ -1,3 +1,2 @@ -from .epochs_bivariate import spectral_connectivity_epochs -from .epochs_multivariate import spectral_connectivity_epochs_multivariate +from .epochs import spectral_connectivity_epochs from .time import spectral_connectivity_time \ No newline at end of file diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 5742ae33..7ad551c9 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -2,6 +2,8 @@ # Denis A. Engemann # Adam Li # Thomas S. Binns +# Tien D. Nguyen +# Richard M. Köhler # # License: BSD (3-clause) @@ -9,16 +11,20 @@ import inspect import numpy as np +import scipy as sp from mne.epochs import BaseEpochs from mne.parallel import parallel_func from mne.source_estimate import _BaseSourceEstimate -from mne.time_frequency.multitaper import ( - _csd_from_mt, _mt_spectra, _psd_from_mt, _psd_from_mt_adaptive) +from mne.time_frequency.multitaper import (_csd_from_mt, + _mt_spectra, _psd_from_mt, + _psd_from_mt_adaptive) from mne.time_frequency.tfr import cwt, morlet from mne.time_frequency.multitaper import _compute_mt_params -from mne.utils import _arange_div, _check_option, _time_mask, logger, warn +from mne.utils import ( + ProgressBar, _arange_div, _check_option, _time_mask, logger, warn, verbose) -from ..base import SpectralConnectivity, SpectroTemporalConnectivity +from ..base import (SpectralConnectivity, SpectroTemporalConnectivity) +from ..utils import fill_doc, check_indices, check_multivariate_indices def _compute_freqs(n_times, sfreq, cwt_freqs, mode): @@ -57,8 +63,10 @@ def _compute_freq_mask(freqs_all, fmin, fmax, fskip): return freq_mask -def _prepare_connectivity(epoch_block, times_in, tmin, tmax, fmin, fmax, sfreq, - mode, fskip, n_bands, cwt_freqs, faverage): +def _prepare_connectivity(epoch_block, times_in, tmin, tmax, + fmin, fmax, sfreq, indices, + method, mode, fskip, n_bands, + cwt_freqs, faverage): """Check and precompute dimensions of results data.""" first_epoch = epoch_block[0] @@ -84,6 +92,43 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, fmin, fmax, sfreq, times = times_in[tmin_idx:tmax_idx] n_times = len(times) + if any(this_method in _multivariate_methods for this_method in method): + multivariate_con = True + else: + multivariate_con = False + + if indices is None: + if multivariate_con: + if any(this_method in _gc_methods for this_method in method): + raise ValueError( + 'indices must be specified when computing Granger ' + 'causality, as all-to-all connectivity is not supported') + else: + logger.info('using all indices for multivariate connectivity') + indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) + else: + logger.info('only using indices for lower-triangular matrix') + # only compute r for lower-triangular region + indices_use = np.tril_indices(n_signals, -1) + else: + if multivariate_con: + indices_use = check_multivariate_indices(indices) # pad with -1 + if any(this_method in _gc_methods for this_method in method): + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries + raise ValueError( + 'seed and target indices must not intersect when ' + 'computing Granger causality') + else: + indices_use = check_indices(indices) + + # number of connectivities to compute + n_cons = len(indices_use[0]) + + logger.info(' computing connectivity for %d connections' + % n_cons) logger.info(' using t=%0.3fs..%0.3fs for estimation (%d points)' % (tmin_true, tmax_true, n_times)) @@ -139,9 +184,55 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, fmin, fmax, sfreq, logger.info(' connectivity scores will be averaged for ' 'each band') - return (times, n_times, times_in, n_times_in, tmin_idx, + return (n_cons, times, n_times, times_in, n_times_in, tmin_idx, tmax_idx, n_freqs, freq_mask, freqs, freqs_bands, freq_idx_bands, - n_signals, warn_times) + n_signals, indices_use, warn_times) + + +def _check_rank_input(rank, data, indices): + """Check the rank argument is appropriate and compute rank if missing.""" + sv_tol = 1e-10 # tolerance for non-zero singular val (rel to largest) + if rank is None: + rank = np.zeros((2, len(indices[0])), dtype=int) + + if isinstance(data, BaseEpochs): + data_arr = data.get_data() + else: + data_arr = data + + # XXX: Unpadding of arrays after already padding them is perhaps not so + # efficient. However, we need to remove the padded values to + # ensure only the correct channels are indexed, and having two + # versions of indices is a bit messy currently. A candidate for + # refactoring to simplify code. + + for group_i in range(2): # seeds and targets + for con_i, con_idcs in enumerate(indices[group_i]): + con_idcs = con_idcs[con_idcs != -1] # -1 is padded value + s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) + rank[group_i][con_i] = np.min( + [np.count_nonzero(epoch >= epoch[0] * sv_tol) + for epoch in s]) + + logger.info('Estimated data ranks:') + con_i = 1 + for seed_rank, target_rank in zip(rank[0], rank[1]): + logger.info(' connection %i - seeds (%i); targets (%i)' + % (con_i, seed_rank, target_rank, )) + con_i += 1 + rank = tuple((np.array(rank[0]), np.array(rank[1]))) + + else: + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], rank[0], rank[1]): + if not (0 < seed_rank <= len(seed_idcs) and + 0 < target_rank <= len(target_idcs)): + raise ValueError( + 'ranks for seeds and targets must be > 0 and <= the ' + 'number of channels in the seeds and targets, ' + 'respectively, for each connection') + + return rank def _assemble_spectral_params(mode, n_times, mt_adaptive, mt_bandwidth, sfreq, @@ -183,46 +274,6 @@ def _assemble_spectral_params(mode, n_times, mt_adaptive, mt_bandwidth, sfreq, return spectral_params, mt_adaptive, n_times_spectrum, n_tapers -def _compute_spectral_methods_epochs( - con_methods, epoch_block, epoch_idx, call_params, parallel, - my_spectral_connectivity_epochs, n_jobs, n_times_in, times_in, - warn_times -): - """Compute CSD/PSD for spectral_connectivity_epochs... functions.""" - # check dimensions and time scale - for this_epoch in epoch_block: - _, _, _, warn_times = _get_and_verify_data_sizes( - this_epoch, call_params["sfreq"], call_params["n_signals"], - n_times_in, times_in, warn_times=warn_times) - - if n_jobs == 1: - # no parallel processing - for this_epoch in epoch_block: - logger.info(' computing cross-spectral density for epoch %d' - % (epoch_idx + 1)) - # con methods and psd are updated inplace - _epoch_spectral_connectivity(data=this_epoch, **call_params) - epoch_idx += 1 - else: - # process epochs in parallel - logger.info( - ' computing cross-spectral density for epochs %d..%d' - % (epoch_idx + 1, epoch_idx + len(epoch_block))) - - out = parallel(my_spectral_connectivity_epochs( - data=this_epoch, **call_params) - for this_epoch in epoch_block) - # do the accumulation - for this_out in out: - for _method, parallel_method in zip(con_methods, this_out[0]): - _method.combine(parallel_method) - if call_params["psd"] is not None: - call_params["psd"] += this_out[1] - - epoch_idx += len(epoch_block) - - return epoch_idx - ######################################################################## # Various connectivity estimators @@ -242,9 +293,996 @@ def combine(self, other): def compute_con(self, con_idx, n_epochs): raise NotImplementedError('compute_con method not implemented') + +class _EpochMeanConEstBase(_AbstractConEstBase): + """Base class for methods that estimate connectivity as mean epoch-wise.""" + + patterns = None + + def __init__(self, n_cons, n_freqs, n_times): + self.n_cons = n_cons + self.n_freqs = n_freqs + self.n_times = n_times + + if n_times == 0: + self.csd_shape = (n_cons, n_freqs) + else: + self.csd_shape = (n_cons, n_freqs, n_times) + + self.con_scores = None + + def start_epoch(self): # noqa: D401 + """Called at the start of each epoch.""" + pass # for this type of con. method we don't do anything + + def combine(self, other): + """Include con. accumated for some epochs in this estimate.""" + self._acc += other._acc + + +class _EpochMeanMultivariateConEstBase(_AbstractConEstBase): + """Base class for mean epoch-wise multivar. con. estimation methods.""" + + n_steps = None + patterns = None + + def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): + self.n_signals = n_signals + self.n_cons = n_cons + self.n_freqs = n_freqs + self.n_times = n_times + self.n_jobs = n_jobs + + # include time dimension, even when unused for indexing flexibility + if n_times == 0: + self.csd_shape = (n_signals**2, n_freqs) + self.con_scores = np.zeros((n_cons, n_freqs, 1)) + else: + self.csd_shape = (n_signals**2, n_freqs, n_times) + self.con_scores = np.zeros((n_cons, n_freqs, n_times)) + + # allocate space for accumulation of CSD + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + self._compute_n_progress_bar_steps() + + def start_epoch(self): # noqa: D401 + """Called at the start of each epoch.""" + pass # for this type of con. method we don't do anything + + def combine(self, other): + """Include con. accumulated for some epochs in this estimate.""" + self._acc += other._acc + + def accumulate(self, con_idx, csd_xy): + """Accumulate CSD for some connections.""" + self._acc[con_idx] += csd_xy + + def _compute_n_progress_bar_steps(self): + """Calculate the number of steps to include in the progress bar.""" + self.n_steps = int(np.ceil(self.n_freqs / self.n_jobs)) + + def _log_connection_number(self, con_i): + """Log the number of the connection being computed.""" + logger.info('Computing %s for connection %i of %i' + % (self.name, con_i + 1, self.n_cons, )) + + def _get_block_indices(self, block_i, limit): + """Get indices for a computation block capped by a limit.""" + indices = np.arange(block_i * self.n_jobs, (block_i + 1) * self.n_jobs) + + return indices[np.nonzero(indices < limit)] + + def reshape_csd(self): + """Reshape CSD into a matrix of times x freqs x signals x signals.""" + if self.n_times == 0: + return (np.reshape(self._acc, ( + self.n_signals, self.n_signals, self.n_freqs, 1) + ).transpose(3, 2, 0, 1)) + + return (np.reshape(self._acc, ( + self.n_signals, self.n_signals, self.n_freqs, self.n_times) + ).transpose(3, 2, 0, 1)) + + +class _CohEstBase(_EpochMeanConEstBase): + """Base Estimator for Coherence, Coherency, Imag. Coherence.""" + + accumulate_psd = True + + def __init__(self, n_cons, n_freqs, n_times): + super(_CohEstBase, self).__init__(n_cons, n_freqs, n_times) + + # allocate space for accumulation of CSD + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate CSD for some connections.""" + self._acc[con_idx] += csd_xy + + +class _CohEst(_CohEstBase): + """Coherence Estimator.""" + + name = 'Coherence' + + def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + csd_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy) + + +class _CohyEst(_CohEstBase): + """Coherency Estimator.""" + + name = 'Coherency' + + def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape, + dtype=np.complex128) + csd_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = csd_mean / np.sqrt(psd_xx * psd_yy) + + +class _ImCohEst(_CohEstBase): + """Imaginary Coherence Estimator.""" + + name = 'Imaginary Coherence' + + def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + csd_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = np.imag(csd_mean) / np.sqrt(psd_xx * psd_yy) + + +class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): + """Base estimator for multivariate imag. part of coherency methods. + + See Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 + for equation references. + """ + + name = None + accumulate_psd = False + + def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): + super(_MultivariateCohEstBase, self).__init__( + n_signals, n_cons, n_freqs, n_times, n_jobs) + + def compute_con(self, indices, ranks, n_epochs=1): + """Compute multivariate imag. part of coherency between signals.""" + assert self.name in ['MIC', 'MIM'], ( + 'the class name is not recognised, please contact the ' + 'mne-connectivity developers') + + csd = self.reshape_csd() / n_epochs + n_times = csd.shape[0] + times = np.arange(n_times) + freqs = np.arange(self.n_freqs) + + if self.name == 'MIC': + self.patterns = np.full( + (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), + np.nan) + + con_i = 0 + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], ranks[0], ranks[1]): + self._log_connection_number(con_i) + + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] + con_idcs = [*seed_idcs, *target_idcs] + + C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] + + # Eqs. 32 & 33 + C_bar, U_bar_aa, U_bar_bb = self._csd_svd( + C, seed_idcs, seed_rank, target_rank) + + # Eqs. 3 & 4 + E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) + + if self.name == 'MIC': + self._compute_mic(E, C, seed_idcs, target_idcs, n_times, + U_bar_aa, U_bar_bb, con_i) + else: + self._compute_mim(E, seed_idcs, target_idcs, con_i) + + con_i += 1 + + self.reshape_results() + + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): + """Dimensionality reduction of CSD with SVD.""" + n_times = csd.shape[0] + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds + + C_aa = csd[..., :n_seeds, :n_seeds] + C_ab = csd[..., :n_seeds, n_seeds:] + C_bb = csd[..., n_seeds:, n_seeds:] + C_ba = csd[..., n_seeds:, :n_seeds] + + # Eq. 32 + if seed_rank != n_seeds: + U_aa = np.linalg.svd(np.real(C_aa), full_matrices=False)[0] + U_bar_aa = U_aa[..., :seed_rank] + else: + U_bar_aa = np.broadcast_to( + np.identity(n_seeds), + (n_times, self.n_freqs) + (n_seeds, n_seeds)) + + if target_rank != n_targets: + U_bb = np.linalg.svd(np.real(C_bb), full_matrices=False)[0] + U_bar_bb = U_bb[..., :target_rank] + else: + U_bar_bb = np.broadcast_to( + np.identity(n_targets), + (n_times, self.n_freqs) + (n_targets, n_targets)) + + # Eq. 33 + C_bar_aa = np.matmul( + U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) + C_bar_ab = np.matmul( + U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb)) + C_bar_bb = np.matmul( + U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb)) + C_bar_ba = np.matmul( + U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa)) + C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), + np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) + + return C_bar, U_bar_aa, U_bar_bb + + def _compute_e(self, csd, n_seeds): + """Compute E from the CSD.""" + C_r = np.real(csd) + + parallel, parallel_compute_t, _ = parallel_func( + _mic_mim_compute_t, self.n_jobs, verbose=False) + + # imag. part of T filled when data is rank-deficient + T = np.zeros(csd.shape, dtype=np.complex128) + for block_i in ProgressBar( + range(self.n_steps), mesg="frequency blocks"): + freqs = self._get_block_indices(block_i, self.n_freqs) + T[:, freqs] = np.array(parallel(parallel_compute_t( + C_r[:, f], T[:, f], n_seeds) for f in freqs) + ).transpose(1, 0, 2, 3) + + if not np.isreal(T).all() or not np.isfinite(T).all(): + raise RuntimeError( + 'the transformation matrix of the data must be real-valued ' + 'and contain no NaN or infinity values; check that you are ' + 'using full rank data or specify an appropriate rank for the ' + 'seeds and targets that is less than or equal to their ranks') + T = np.real(T) # make T real if check passes + + # Eq. 4 + D = np.matmul(T, np.matmul(csd, T)) + + # E as imag. part of D between seeds and targets + return np.imag(D[..., :n_seeds, n_seeds:]) + + def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, + U_bar_bb, con_i): + """Compute MIC and the associated spatial patterns.""" + n_seeds = len(seed_idcs) + n_targets = len(target_idcs) + times = np.arange(n_times) + freqs = np.arange(self.n_freqs) + + # Eigendecomp. to find spatial filters for seeds and targets + w_seeds, V_seeds = np.linalg.eigh( + np.matmul(E, E.transpose(0, 1, 3, 2))) + w_targets, V_targets = np.linalg.eigh( + np.matmul(E.transpose(0, 1, 3, 2), E)) + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + ): + # strange edge-case where the eigenvectors returned should be a set + # of identity matrices with one rotated by 90 degrees, but are + # instead identical (i.e. are not rotated versions of one another). + # This leads to the case where the spatial filters are incorrectly + # applied, resulting in connectivity estimates of ~0 when they + # should be perfectly correlated ~1. Accordingly, we manually + # create a set of rotated identity matrices to use as the filters. + create_filter = False + stop = False + while not create_filter and not stop: + for time_i in range(n_times): + for freq_i in range(self.n_freqs): + if np.all(V_seeds[time_i, freq_i] == + V_targets[time_i, freq_i]): + create_filter = True + break + stop = True + if create_filter: + n_chans = E.shape[2] + eye_4d = np.zeros_like(V_seeds) + eye_4d[:, :, np.arange(n_chans), np.arange(n_chans)] = 1 + V_seeds = eye_4d + V_targets = np.rot90(eye_4d, axes=(2, 3)) + + # Spatial filters with largest eigval. for seeds and targets + alpha = V_seeds[times[:, None], freqs, :, w_seeds.argmax(axis=2)] + beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] + + # Eq. 46 (seed spatial patterns) + self.patterns[0, con_i, :n_seeds] = (np.matmul( + np.real(C[..., :n_seeds, :n_seeds]), + np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T + + # Eq. 47 (target spatial patterns) + self.patterns[1, con_i, :n_targets] = (np.matmul( + np.real(C[..., n_seeds:, n_seeds:]), + np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T + + # Eq. 7 + self.con_scores[con_i] = (np.einsum( + 'ijk,ijk->ij', alpha, np.matmul(E, np.expand_dims( + beta, axis=3))[..., 0] + ) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2)).T + + def _compute_mim(self, E, seed_idcs, target_idcs, con_i): + """Compute MIM (a.k.a. GIM if seeds == targets).""" + # Eq. 14 + self.con_scores[con_i] = np.matmul( + E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T + + # Eq. 15 + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + ): + self.con_scores[con_i] *= 0.5 + + def reshape_results(self): + """Remove time dimension from results, if necessary.""" + if self.n_times == 0: + self.con_scores = self.con_scores[..., 0] + if self.patterns is not None: + self.patterns = self.patterns[..., 0] + + +def _mic_mim_compute_t(C, T, n_seeds): + """Compute T for a single frequency (used for MIC and MIM).""" + for time_i in range(C.shape[0]): + T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( + C[time_i, :n_seeds, :n_seeds], -0.5 + ) + T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power( + C[time_i, n_seeds:, n_seeds:], -0.5 + ) + + return T + + +class _MICEst(_MultivariateCohEstBase): + """Multivariate imaginary part of coherency (MIC) estimator.""" + + name = "MIC" + + +class _MIMEst(_MultivariateCohEstBase): + """Multivariate interaction measure (MIM) estimator.""" + + name = "MIM" + + +class _PLVEst(_EpochMeanConEstBase): + """PLV Estimator.""" + + name = 'PLV' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_PLVEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += csd_xy / np.abs(csd_xy) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + plv = np.abs(self._acc / n_epochs) + self.con_scores[con_idx] = plv + + +class _ciPLVEst(_EpochMeanConEstBase): + """corrected imaginary PLV Estimator.""" + + name = 'ciPLV' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_ciPLVEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += csd_xy / np.abs(csd_xy) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + imag_plv = np.abs(np.imag(self._acc)) / n_epochs + real_plv = np.real(self._acc) / n_epochs + real_plv = np.clip(real_plv, -1, 1) # bounded from -1 to 1 + mask = (np.abs(real_plv) == 1) # avoid division by 0 + real_plv[mask] = 0 + corrected_imag_plv = imag_plv / np.sqrt(1 - real_plv ** 2) + self.con_scores[con_idx] = corrected_imag_plv + + +class _PLIEst(_EpochMeanConEstBase): + """PLI Estimator.""" + + name = 'PLI' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_PLIEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += np.sign(np.imag(csd_xy)) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + pli_mean = self._acc[con_idx] / n_epochs + self.con_scores[con_idx] = np.abs(pli_mean) + + +class _PLIUnbiasedEst(_PLIEst): + """Unbiased PLI Square Estimator.""" + + name = 'Unbiased PLI Square' + accumulate_psd = False + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + pli_mean = self._acc[con_idx] / n_epochs + + # See Vinck paper Eq. (30) + con = (n_epochs * pli_mean ** 2 - 1) / (n_epochs - 1) + + self.con_scores[con_idx] = con + + +class _DPLIEst(_EpochMeanConEstBase): + """DPLI Estimator.""" + + name = 'DPLI' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_DPLIEst, self).__init__(n_cons, n_freqs, n_times) + + # allocate accumulator + self._acc = np.zeros(self.csd_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + self._acc[con_idx] += np.heaviside(np.imag(csd_xy), 0.5) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + con = self._acc[con_idx] / n_epochs + + self.con_scores[con_idx] = con + + +class _WPLIEst(_EpochMeanConEstBase): + """WPLI Estimator.""" + + name = 'WPLI' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_WPLIEst, self).__init__(n_cons, n_freqs, n_times) + + # store both imag(csd) and abs(imag(csd)) + acc_shape = (2,) + self.csd_shape + self._acc = np.zeros(acc_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + im_csd = np.imag(csd_xy) + self._acc[0, con_idx] += im_csd + self._acc[1, con_idx] += np.abs(im_csd) + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + num = np.abs(self._acc[0, con_idx]) + denom = self._acc[1, con_idx] + + # handle zeros in denominator + z_denom = np.where(denom == 0.) + denom[z_denom] = 1. + + con = num / denom + + # where we had zeros in denominator, we set con to zero + con[z_denom] = 0. + + self.con_scores[con_idx] = con + + +class _WPLIDebiasedEst(_EpochMeanConEstBase): + """Debiased WPLI Square Estimator.""" + + name = 'Debiased WPLI Square' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_WPLIDebiasedEst, self).__init__(n_cons, n_freqs, n_times) + # store imag(csd), abs(imag(csd)), imag(csd)^2 + acc_shape = (3,) + self.csd_shape + self._acc = np.zeros(acc_shape) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + im_csd = np.imag(csd_xy) + self._acc[0, con_idx] += im_csd + self._acc[1, con_idx] += np.abs(im_csd) + self._acc[2, con_idx] += im_csd ** 2 + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + # note: we use the trick from fieldtrip to compute the + # the estimate over all pairwise epoch combinations + sum_im_csd = self._acc[0, con_idx] + sum_abs_im_csd = self._acc[1, con_idx] + sum_sq_im_csd = self._acc[2, con_idx] + + denom = sum_abs_im_csd ** 2 - sum_sq_im_csd + + # handle zeros in denominator + z_denom = np.where(denom == 0.) + denom[z_denom] = 1. + + con = (sum_im_csd ** 2 - sum_sq_im_csd) / denom + + # where we had zeros in denominator, we set con to zero + con[z_denom] = 0. + + self.con_scores[con_idx] = con + + +class _PPCEst(_EpochMeanConEstBase): + """Pairwise Phase Consistency (PPC) Estimator.""" + + name = 'PPC' + accumulate_psd = False + + def __init__(self, n_cons, n_freqs, n_times): + super(_PPCEst, self).__init__(n_cons, n_freqs, n_times) + + # store csd / abs(csd) + self._acc = np.zeros(self.csd_shape, dtype=np.complex128) + + def accumulate(self, con_idx, csd_xy): + """Accumulate some connections.""" + denom = np.abs(csd_xy) + z_denom = np.where(denom == 0.) + denom[z_denom] = 1. + this_acc = csd_xy / denom + this_acc[z_denom] = 0. # handle division by zero + + self._acc[con_idx] += this_acc + + def compute_con(self, con_idx, n_epochs): + """Compute final con. score for some connections.""" + if self.con_scores is None: + self.con_scores = np.zeros(self.csd_shape) + + # note: we use the trick from fieldtrip to compute the + # the estimate over all pairwise epoch combinations + con = ((self._acc[con_idx] * np.conj(self._acc[con_idx]) - n_epochs) / + (n_epochs * (n_epochs - 1.))) + + self.con_scores[con_idx] = np.real(con) + + +class _GCEstBase(_EpochMeanMultivariateConEstBase): + """Base multivariate state-space Granger causality estimator.""" + + accumulate_psd = False + + def __init__(self, n_signals, n_cons, n_freqs, n_times, n_lags, n_jobs=1): + super(_GCEstBase, self).__init__( + n_signals, n_cons, n_freqs, n_times, n_jobs) + + self.freq_res = (self.n_freqs - 1) * 2 + if n_lags >= self.freq_res: + raise ValueError( + 'the number of lags (%i) must be less than double the ' + 'frequency resolution (%i)' % (n_lags, self.freq_res, )) + self.n_lags = n_lags + + def compute_con(self, indices, ranks, n_epochs=1): + """Compute multivariate state-space Granger causality.""" + assert self.name in ['GC', 'GC time-reversed'], ( + 'the class name is not recognised, please contact the ' + 'mne-connectivity developers') + + csd = self.reshape_csd() / n_epochs + + n_times = csd.shape[0] + times = np.arange(n_times) + freqs = np.arange(self.n_freqs) + + con_i = 0 + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], ranks[0], ranks[1]): + self._log_connection_number(con_i) + + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] + con_idcs = [*seed_idcs, *target_idcs] + + C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] + + C_bar = self._csd_svd(C, seed_idcs, seed_rank, target_rank) + n_signals = seed_rank + target_rank + con_seeds = np.arange(seed_rank) + con_targets = np.arange(target_rank) + seed_rank + + autocov = self._compute_autocov(C_bar) + if self.name == "GC time-reversed": + autocov = autocov.transpose(0, 1, 3, 2) + + A_f, V = self._autocov_to_full_var(autocov) + A_f_3d = np.reshape( + A_f, (n_times, n_signals, n_signals * self.n_lags), + order="F") + A, K = self._full_var_to_iss(A_f_3d) + + self.con_scores[con_i] = self._iss_to_ugc( + A, A_f_3d, K, V, con_seeds, con_targets) + + con_i += 1 + + self.reshape_results() + + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): + """Dimensionality reduction of CSD with SVD on the covariance.""" + # sum over times and epochs to get cov. from CSD + cov = csd.sum(axis=(0, 1)) + + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds + + cov_aa = cov[:n_seeds, :n_seeds] + cov_bb = cov[n_seeds:, n_seeds:] + + if seed_rank != n_seeds: + U_aa = np.linalg.svd(np.real(cov_aa), full_matrices=False)[0] + U_bar_aa = U_aa[:, :seed_rank] + else: + U_bar_aa = np.identity(n_seeds) + + if target_rank != n_targets: + U_bb = np.linalg.svd(np.real(cov_bb), full_matrices=False)[0] + U_bar_bb = U_bb[:, :target_rank] + else: + U_bar_bb = np.identity(n_targets) + + C_aa = csd[..., :n_seeds, :n_seeds] + C_ab = csd[..., :n_seeds, n_seeds:] + C_bb = csd[..., n_seeds:, n_seeds:] + C_ba = csd[..., n_seeds:, :n_seeds] + + C_bar_aa = np.matmul( + U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa)) + C_bar_ab = np.matmul( + U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb)) + C_bar_bb = np.matmul( + U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb)) + C_bar_ba = np.matmul( + U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa)) + C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), + np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) + + return C_bar + + def _compute_autocov(self, csd): + """Compute autocovariance from the CSD.""" + n_times = csd.shape[0] + n_signals = csd.shape[2] + + circular_shifted_csd = np.concatenate( + [np.flip(np.conj(csd[:, 1:]), axis=1), csd[:, :-1]], axis=1) + ifft_shifted_csd = self._block_ifft( + circular_shifted_csd, self.freq_res) + lags_ifft_shifted_csd = np.reshape( + ifft_shifted_csd[:, :self.n_lags + 1], + (n_times, self.n_lags + 1, n_signals ** 2), order="F") + + signs = np.repeat([1], self.n_lags + 1).tolist() + signs[1::2] = [x * -1 for x in signs[1::2]] + sign_matrix = np.repeat( + np.tile(np.array(signs), (n_signals ** 2, 1))[np.newaxis], + n_times, axis=0).transpose(0, 2, 1) + + return np.real(np.reshape( + sign_matrix * lags_ifft_shifted_csd, + (n_times, self.n_lags + 1, n_signals, n_signals), order="F")) + + def _block_ifft(self, csd, n_points): + """Compute block iFFT with n points.""" + shape = csd.shape + csd_3d = np.reshape( + csd, (shape[0], shape[1], shape[2] * shape[3]), order="F") + + csd_ifft = np.fft.ifft(csd_3d, n=n_points, axis=1) + + return np.reshape(csd_ifft, shape, order="F") + + def _autocov_to_full_var(self, autocov): + """Compute full VAR model using Whittle's LWR recursion.""" + if np.any(np.linalg.det(autocov) == 0): + raise RuntimeError( + 'the autocovariance matrix is singular; check if your data is ' + 'rank deficient and specify an appropriate rank argument <= ' + 'the rank of the seeds and targets') + + A_f, V = self._whittle_lwr_recursion(autocov) + + if not np.isfinite(A_f).all(): + raise RuntimeError('at least one VAR model coefficient is ' + 'infinite or NaN; check the data you are using') + + try: + np.linalg.cholesky(V) + except np.linalg.LinAlgError as np_error: + raise RuntimeError( + 'the covariance matrix of the residuals is not ' + 'positive-definite; check the singular values of your data ' + 'and specify an appropriate rank argument <= the rank of the ' + 'seeds and targets') from np_error + + return A_f, V + + def _whittle_lwr_recursion(self, G): + """Solve Yule-Walker eqs. for full VAR params. with LWR recursion. + + See: Whittle P., 1963. Biometrika, DOI: 10.1093/biomet/50.1-2.129 + """ + # Initialise recursion + n = G.shape[2] # number of signals + q = G.shape[1] - 1 # number of lags + t = G.shape[0] # number of times + qn = n * q + + cov = G[:, 0, :, :] # covariance + G_f = np.reshape( + G[:, 1:, :, :].transpose(0, 3, 1, 2), (t, qn, n), + order="F") # forward autocov + G_b = np.reshape( + np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), + order="F").transpose(0, 2, 1) # backward autocov + + A_f = np.zeros((t, n, qn)) # forward coefficients + A_b = np.zeros((t, n, qn)) # backward coefficients + + k = 1 # model order + r = q - k + k_f = np.arange(k * n) # forward indices + k_b = np.arange(r * n, qn) # backward indices + + try: + A_f[:, :, k_f] = np.linalg.solve( + cov, G_b[:, k_b, :].transpose(0, 2, 1)).transpose(0, 2, 1) + A_b[:, :, k_b] = np.linalg.solve( + cov, G_f[:, k_f, :].transpose(0, 2, 1)).transpose(0, 2, 1) + + # Perform recursion + for k in np.arange(2, q + 1): + var_A = (G_b[:, (r - 1) * n: r * n, :] - + np.matmul(A_f[:, :, k_f], G_b[:, k_b, :])) + var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :]) + AA_f = np.linalg.solve( + var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) + + var_A = (G_f[:, (k - 1) * n: k * n, :] - + np.matmul(A_b[:, :, k_b], G_f[:, k_f, :])) + var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :]) + AA_b = np.linalg.solve( + var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) + + A_f_previous = A_f[:, :, k_f] + A_b_previous = A_b[:, :, k_b] + + r = q - k + k_f = np.arange(k * n) + k_b = np.arange(r * n, qn) + + A_f[:, :, k_f] = np.dstack( + (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f)) + A_b[:, :, k_b] = np.dstack( + (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous))) + except np.linalg.LinAlgError as np_error: + raise RuntimeError( + 'the autocovariance matrix is singular; check if your data is ' + 'rank deficient and specify an appropriate rank argument <= ' + 'the rank of the seeds and targets') from np_error + + V = cov - np.matmul(A_f, G_f) + A_f = np.reshape(A_f, (t, n, n, q), order="F") + + return A_f, V + + def _full_var_to_iss(self, A_f): + """Compute innovations-form parameters for a state-space model. + + Parameters computed from a full VAR model using Aoki's method. For a + non-moving-average full VAR model, the state-space parameter C + (observation matrix) is identical to AF of the VAR model. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + t = A_f.shape[0] + m = A_f.shape[1] # number of signals + p = A_f.shape[2] // m # number of autoregressive lags + + I_p = np.dstack(t * [np.eye(m * p)]).transpose(2, 0, 1) + A = np.hstack((A_f, I_p[:, : (m * p - m), :])) # state transition + # matrix + K = np.hstack(( + np.dstack(t * [np.eye(m)]).transpose(2, 0, 1), + np.zeros((t, (m * (p - 1)), m)))) # Kalman gain matrix + + return A, K + + def _iss_to_ugc(self, A, C, K, V, seeds, targets): + """Compute unconditional GC from innovations-form state-space params. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + times = np.arange(A.shape[0]) + freqs = np.arange(self.n_freqs) + z = np.exp(-1j * np.pi * np.linspace(0, 1, self.n_freqs)) # points + # on a unit circle in the complex plane, one for each frequency + + H = self._iss_to_tf(A, C, K, z) # spectral transfer function + V_22_1 = np.linalg.cholesky(self._partial_covar(V, seeds, targets)) + HV = np.matmul(H, np.linalg.cholesky(V)) + S = np.matmul(HV, HV.conj().transpose(0, 1, 3, 2)) # Eq. 6 + S_11 = S[np.ix_(freqs, times, targets, targets)] + HV_12 = np.matmul(H[np.ix_(freqs, times, targets, seeds)], V_22_1) + HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2)) + + # Eq. 11 + return np.real( + np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH))) + + def _iss_to_tf(self, A, C, K, z): + """Compute transfer function for innovations-form state-space params. + + In the frequency domain, the back-shift operator, z, is a vector of + points on a unit circle in the complex plane. z = e^-iw, where -pi < w + <= pi. + + A note on efficiency: solving over the 4D time-freq. tensor is slower + than looping over times and freqs when n_times and n_freqs high, and + when n_times and n_freqs low, looping over times and freqs very fast + anyway (plus tensor solving doesn't allow for parallelisation). + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + t = A.shape[0] + h = self.n_freqs + n = C.shape[1] + m = A.shape[1] + I_n = np.eye(n) + I_m = np.eye(m) + H = np.zeros((h, t, n, n), dtype=np.complex128) + + parallel, parallel_compute_H, _ = parallel_func( + _gc_compute_H, self.n_jobs, verbose=False + ) + H = np.zeros((h, t, n, n), dtype=np.complex128) + for block_i in ProgressBar( + range(self.n_steps), mesg="frequency blocks" + ): + freqs = self._get_block_indices(block_i, self.n_freqs) + H[freqs] = parallel( + parallel_compute_H(A, C, K, z[k], I_n, I_m) for k in freqs) + + return H + + def _partial_covar(self, V, seeds, targets): + """Compute partial covariance of a matrix. + + Given a covariance matrix V, the partial covariance matrix of V between + indices i and j, given k (V_ij|k), is equivalent to V_ij - V_ik * + V_kk^-1 * V_kj. In this case, i and j are seeds, and k are targets. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101. + """ + times = np.arange(V.shape[0]) + W = np.linalg.solve( + np.linalg.cholesky(V[np.ix_(times, targets, targets)]), + V[np.ix_(times, targets, seeds)], + ) + W = np.matmul(W.transpose(0, 2, 1), W) + + return V[np.ix_(times, seeds, seeds)] - W + + def reshape_results(self): + """Remove time dimension from con. scores, if necessary.""" + if self.n_times == 0: + self.con_scores = self.con_scores[:, :, 0] + + +def _gc_compute_H(A, C, K, z_k, I_n, I_m): + """Compute transfer function for innovations-form state-space params. + + See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: + 10.1103/PhysRevE.91.040101, Eq. 4. + """ + from scipy import linalg # XXX: is this necessary??? + H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) + for t in range(A.shape[0]): + H[t] = I_n + np.matmul( + C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t])) + + return H + + +class _GCEst(_GCEstBase): + """[seeds -> targets] state-space GC estimator.""" + + name = "GC" + + +class _GCTREst(_GCEstBase): + """time-reversed[seeds -> targets] state-space GC estimator.""" + + name = "GC time-reversed" + ############################################################################### +_multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr'] _gc_methods = ['gc', 'gc_tr'] @@ -254,9 +1292,9 @@ def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, block_size, psd, accumulate_psd, con_method_types, con_methods, n_signals, n_signals_use, n_times, gc_n_lags, - multivariate_con, accumulate_inplace=True): + accumulate_inplace=True): """Estimate connectivity for one epoch (see spectral_connectivity).""" - if multivariate_con: + if any(this_method in _multivariate_methods for this_method in method): n_con_signals = n_signals_use ** 2 else: n_con_signals = n_cons @@ -273,7 +1311,8 @@ def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, con_methods = [] for mtype in con_method_types: method_params = list(inspect.signature(mtype).parameters) - if multivariate_con: + if "n_signals" in method_params: + # if it's a multivariate connectivity method if "n_lags" in method_params: # if it's a Granger causality method con_methods.append( @@ -462,12 +1501,22 @@ def _get_and_verify_data_sizes(data, sfreq, n_signals=None, n_times=None, return n_signals, n_times, times, warn_times -def _check_estimators(method, con_method_map): +# map names to estimator types +_CON_METHOD_MAP = {'coh': _CohEst, 'cohy': _CohyEst, 'imcoh': _ImCohEst, + 'plv': _PLVEst, 'ciplv': _ciPLVEst, 'ppc': _PPCEst, + 'pli': _PLIEst, 'pli2_unbiased': _PLIUnbiasedEst, + 'dpli': _DPLIEst, 'wpli': _WPLIEst, + 'wpli2_debiased': _WPLIDebiasedEst, 'mic': _MICEst, + 'mim': _MIMEst, 'gc': _GCEst, 'gc_tr': _GCTREst} + + +def _check_estimators(method): """Check construction of connectivity estimators.""" + n_methods = len(method) con_method_types = list() for this_method in method: - if this_method in con_method_map: - con_method_types.append(con_method_map[this_method]) + if this_method in _CON_METHOD_MAP: + con_method_types.append(_CON_METHOD_MAP[this_method]) elif isinstance(this_method, str): raise ValueError('%s is not a valid connectivity method' % this_method) @@ -483,18 +1532,290 @@ def _check_estimators(method, con_method_map): accumulate_psd = any( this_method.accumulate_psd for this_method in con_method_types) - return con_method_types, accumulate_psd + return con_method_types, n_methods, accumulate_psd + + +@ verbose +@ fill_doc +def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, + sfreq=None, + mode='multitaper', fmin=None, fmax=np.inf, + fskip=0, faverage=False, tmin=None, tmax=None, + mt_bandwidth=None, mt_adaptive=False, + mt_low_bias=True, cwt_freqs=None, + cwt_n_cycles=7, gc_n_lags=40, rank=None, + block_size=1000, n_jobs=1, verbose=None): + r"""Compute frequency- and time-frequency-domain connectivity measures. + + The connectivity method(s) are specified using the "method" parameter. + All methods are based on estimates of the cross- and power spectral + densities (CSD/PSD) Sxy and Sxx, Syy. + + Parameters + ---------- + data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs + The data from which to compute connectivity. Note that it is also + possible to combine multiple signals by providing a list of tuples, + e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], + corresponds to 3 epochs, and arr_* could be an array with the same + number of time points as stc_*. The array-like object can also + be a list/generator of array, shape =(n_signals, n_times), + or a list/generator of SourceEstimate or VolSourceEstimate objects. + %(names)s + method : str | list of str + Connectivity measure(s) to compute. These can be ``['coh', 'cohy', + 'imcoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', + 'wpli2_debiased', 'gc', 'gc_tr']``. Multivariate methods (``['mic', + 'mim', 'gc', 'gc_tr]``) cannot be called with the other methods. + indices : tuple of array | None + Two arrays with indices of connections for which to compute + connectivity. If a bivariate method is called, each array for the seeds + and targets should contain the channel indices for each bivariate + connection. If a multivariate method is called, each array for the + seeds and targets should consist of nested arrays containing + the channel indices for each multivariate connection. If ``None``, + connections between all channels are computed, unless a Granger + causality method is called, in which case an error is raised. + sfreq : float + The sampling frequency. Required if data is not + :class:`Epochs `. + mode : str + Spectrum estimation mode can be either: 'multitaper', 'fourier', or + 'cwt_morlet'. + fmin : float | tuple of float + The lower frequency of interest. Multiple bands are defined using + a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. + fmax : float | tuple of float + The upper frequency of interest. Multiple bands are dedined using + a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. + fskip : int + Omit every "(fskip + 1)-th" frequency bin to decimate in frequency + domain. + faverage : bool + Average connectivity scores for each frequency band. If True, + the output freqs will be a list with arrays of the frequencies + that were averaged. + tmin : float | None + Time to start connectivity estimation. Note: when "data" is an array, + the first sample is assumed to be at time 0. For other types + (Epochs, etc.), the time information contained in the object is used + to compute the time indices. + tmax : float | None + Time to end connectivity estimation. Note: when "data" is an array, + the first sample is assumed to be at time 0. For other types + (Epochs, etc.), the time information contained in the object is used + to compute the time indices. + mt_bandwidth : float | None + The bandwidth of the multitaper windowing function in Hz. + Only used in 'multitaper' mode. + mt_adaptive : bool + Use adaptive weights to combine the tapered spectra into PSD. + Only used in 'multitaper' mode. + mt_low_bias : bool + Only use tapers with more than 90 percent spectral concentration + within bandwidth. Only used in 'multitaper' mode. + cwt_freqs : array + Array of frequencies of interest. Only used in 'cwt_morlet' mode. + cwt_n_cycles : float | array of float + Number of cycles. Fixed number or one per frequency. Only used in + 'cwt_morlet' mode. + gc_n_lags : int + Number of lags to use for the vector autoregressive model when + computing Granger causality. Higher values increase computational cost, + but reduce the degree of spectral smoothing in the results. Only used + if ``method`` contains any of ``['gc', 'gc_tr']``. + rank : tuple of array | None + Two arrays with the rank to project the seed and target data to, + respectively, using singular value decomposition. If None, the rank of + the data is computed and projected to. Only used if ``method`` contains + any of ``['mic', 'mim', 'gc', 'gc_tr']``. + block_size : int + How many connections to compute at once (higher numbers are faster + but require more memory). + n_jobs : int + How many samples to process in parallel. + %(verbose)s + + Returns + ------- + con : array | list of array + Computed connectivity measure(s). Either an instance of + ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. + The shape of the connectivity result will be: + + - ``(n_cons, n_freqs)`` for multitaper or fourier modes + - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode + - ``n_cons = n_signals ** 2`` for bivariate methods with + ``indices=None`` + - ``n_cons = 1`` for multivariate methods with ``indices=None`` + - ``n_cons = len(indices[0])`` for bivariate and multivariate methods + when indices is supplied. + + See Also + -------- + mne_connectivity.spectral_connectivity_time + mne_connectivity.SpectralConnectivity + mne_connectivity.SpectroTemporalConnectivity + + Notes + ----- + Please note that the interpretation of the measures in this function + depends on the data and underlying assumptions and does not necessarily + reflect a causal relationship between brain regions. + + These measures are not to be interpreted over time. Each Epoch passed into + the dataset is interpreted as an independent sample of the same + connectivity structure. Within each Epoch, it is assumed that the spectral + measure is stationary. The spectral measures implemented in this function + are computed across Epochs. **Thus, spectral measures computed with only + one Epoch will result in errorful values and spectral measures computed + with few Epochs will be unreliable.** Please see + ``spectral_connectivity_time`` for time-resolved connectivity estimation. + + The spectral densities can be estimated using a multitaper method with + digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier + transform with Hanning windows, or a continuous wavelet transform using + Morlet wavelets. The spectral estimation mode is specified using the + "mode" parameter. + + By default, the connectivity between all signals is computed (only + connections corresponding to the lower-triangular part of the connectivity + matrix). If one is only interested in the connectivity between some + signals, the "indices" parameter can be used. For example, to compute the + connectivity between the signal with index 0 and signals "2, 3, 4" (a total + of 3 connections) one can use the following:: + + indices = (np.array([0, 0, 0]), # row indices + np.array([2, 3, 4])) # col indices + + con = spectral_connectivity_epochs(data, method='coh', + indices=indices, ...) + + In this case con.get_data().shape = (3, n_freqs). The connectivity scores + are in the same order as defined indices. + + For multivariate methods, this is handled differently. If "indices" is + None, connectivity between all signals will be computed and a single + connectivity spectrum will be returned (this is not possible if a Granger + causality method is called). If "indices" is specified, seed and target + indices for each connection should be specified as nested array-likes. For + example, to compute the connectivity between signals (0, 1) -> (2, 3) and + (0, 1) -> (4, 5), indices should be specified as:: + + indices = (np.array([[0, 1], [0, 1]]), # seeds + np.array([[2, 3], [4, 5]])) # targets + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. + + **Supported Connectivity Measures** + + The connectivity method(s) is specified using the "method" parameter. The + following methods are supported (note: ``E[]`` denotes average over + epochs). Multiple measures can be computed at once by using a list/tuple, + e.g., ``['coh', 'pli']`` to compute coherence and PLI. + + 'coh' : Coherence given by:: + + | E[Sxy] | + C = --------------------- + sqrt(E[Sxx] * E[Syy]) + + 'cohy' : Coherency given by:: + + E[Sxy] + C = --------------------- + sqrt(E[Sxx] * E[Syy]) + + 'imcoh' : Imaginary coherence :footcite:`NolteEtAl2004` given by:: + + Im(E[Sxy]) + C = ---------------------- + sqrt(E[Sxx] * E[Syy]) + + 'mic' : Maximised Imaginary part of Coherency (MIC) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} + {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} + \parallel}}` + + where: :math:`\boldsymbol{E}` is the imaginary part of the + transformed cross-spectral density between seeds and targets; and + :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are + eigenvectors for the seeds and targets, such that + :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises + connectivity between the seeds and targets. + + 'mim' : Multivariate Interaction Measure (MIM) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIM=tr(\boldsymbol{EE}^T)` + + 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given + by:: + + PLV = |E[Sxy/|Sxy|]| + + 'ciplv' : corrected imaginary PLV (ciPLV) + :footcite:`BrunaEtAl2018` given by:: + + |E[Im(Sxy/|Sxy|)]| + ciPLV = ------------------------------------ + sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) + + 'ppc' : Pairwise Phase Consistency (PPC), an unbiased estimator + of squared PLV :footcite:`VinckEtAl2010`. + + 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: + + PLI = |E[sign(Im(Sxy))]| + + 'pli2_unbiased' : Unbiased estimator of squared PLI + :footcite:`VinckEtAl2011`. + + 'dpli' : Directed Phase Lag Index (DPLI) :footcite:`StamEtAl2012` + given by (where H is the Heaviside function):: + DPLI = E[H(Im(Sxy))] -def _check_spectral_connectivity_epochs_settings(method, fmin, fmax, n_jobs, - verbose, con_method_map): - """Check settings inputs for spectral_connectivity_epochs... functions.""" + 'wpli' : Weighted Phase Lag Index (WPLI) :footcite:`VinckEtAl2011` + given by:: + + |E[Im(Sxy)]| + WPLI = ------------------ + E[|Im(Sxy)|] + + 'wpli2_debiased' : Debiased estimator of squared WPLI + :footcite:`VinckEtAl2011`. + + 'gc' : State-space Granger Causality (GC) :footcite:`BarnettSeth2015` + given by: + + :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert + \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss + \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, + + where: :math:`s` and :math:`t` represent the seeds and targets, + respectively; :math:`\boldsymbol{H}` is the spectral transfer + function; :math:`\boldsymbol{\Sigma}` is the residuals matrix of + the autoregressive model; and :math:`\boldsymbol{S}` is + :math:`\boldsymbol{\Sigma}` transformed by :math:`\boldsymbol{H}`. + + 'gc_tr' : State-space GC on time-reversed signals + :footcite:`BarnettSeth2015,WinklerEtAl2016` given by the same equation + as for 'gc', but where the autocovariance sequence from which the + autoregressive model is produced is transposed to mimic the reversal of + the original signal in time. + + References + ---------- + .. footbibliography:: + """ if n_jobs != 1: parallel, my_epoch_spectral_connectivity, _ = parallel_func( _epoch_spectral_connectivity, n_jobs, verbose=verbose) - else: - parallel = None - my_epoch_spectral_connectivity = None # format fmin and fmax and check inputs if fmin is None: @@ -506,22 +1827,34 @@ def _check_spectral_connectivity_epochs_settings(method, fmin, fmax, n_jobs, raise ValueError('fmin and fmax must have the same length') if np.any(fmin > fmax): raise ValueError('fmax must be larger than fmin') + n_bands = len(fmin) # assign names to connectivity methods if not isinstance(method, (list, tuple)): method = [method] # make it a list so we can iterate over it - # handle connectivity estimators - con_method_types, accumulate_psd = _check_estimators(method, - con_method_map) - - return (fmin, fmax, n_bands, method, con_method_types, accumulate_psd, - parallel, my_epoch_spectral_connectivity) + if n_bands != 1 and any( + this_method in _gc_methods for this_method in method + ): + raise ValueError('computing Granger causality on multiple frequency ' + 'bands is not yet supported') + + if any(this_method in _multivariate_methods for this_method in method): + if not all(this_method in _multivariate_methods for + this_method in method): + raise ValueError( + 'bivariate and multivariate connectivity methods cannot be ' + 'used in the same function call') + multivariate_con = True + else: + multivariate_con = False + # handle connectivity estimators + (con_method_types, n_methods, accumulate_psd) = _check_estimators(method) -def _check_spectral_connectivity_epochs_data(data, sfreq, names): - """Check data inputs for spectral_connectivity_epochs... functions.""" + events = None + event_id = None if isinstance(data, BaseEpochs): names = data.ch_names times_in = data.times # input times for Epochs input type @@ -543,23 +1876,208 @@ def _check_spectral_connectivity_epochs_data(data, sfreq, names): data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata else: - events = None - event_id = None times_in = None metadata = None if sfreq is None: raise ValueError('Sampling frequency (sfreq) is required with ' 'array input.') - return (names, times_in, sfreq, events, event_id, metadata) + # loop over data; it could be a generator that returns + # (n_signals x n_times) arrays or SourceEstimates + epoch_idx = 0 + logger.info('Connectivity computation...') + warn_times = True + for epoch_block in _get_n_epochs(data, n_jobs): + if epoch_idx == 0: + # initialize everything times and frequencies + (n_cons, times, n_times, times_in, n_times_in, tmin_idx, + tmax_idx, n_freqs, freq_mask, freqs, freqs_bands, freq_idx_bands, + n_signals, indices_use, warn_times) = _prepare_connectivity( + epoch_block=epoch_block, times_in=times_in, + tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, + indices=indices, method=method, mode=mode, fskip=fskip, + n_bands=n_bands, cwt_freqs=cwt_freqs, faverage=faverage) + + # check rank input and compute data ranks if necessary + if multivariate_con: + rank = _check_rank_input(rank, data, indices_use) + else: + rank = None + gc_n_lags = None + + # make sure padded indices are stored in the connectivity object + if multivariate_con and indices is not None: + indices = tuple(np.array(indices_use)) # create a copy + + # get the window function, wavelets, etc for different modes + (spectral_params, mt_adaptive, n_times_spectrum, + n_tapers) = _assemble_spectral_params( + mode=mode, n_times=n_times, mt_adaptive=mt_adaptive, + mt_bandwidth=mt_bandwidth, sfreq=sfreq, + mt_low_bias=mt_low_bias, cwt_n_cycles=cwt_n_cycles, + cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) + + # unique signals for which we actually need to compute PSD etc. + if multivariate_con: + sig_idx = np.unique(np.concatenate(np.concatenate( + indices_use))) + sig_idx = sig_idx[sig_idx != -1] + remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} + remapping[-1] = -1 + remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + remapped_inds[0][con_i] = np.array([ + remapping[idx] for idx in seed]) + remapped_inds[1][con_i] = np.array([ + remapping[idx] for idx in target]) + con_i += 1 + remapped_sig = [remapping[idx] for idx in sig_idx] + else: + sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) + n_signals_use = len(sig_idx) + + # map indices to unique indices + if multivariate_con: + indices_use = remapped_inds # use remapped seeds & targets + idx_map = [np.sort(np.repeat(remapped_sig, len(sig_idx))), + np.tile(remapped_sig, len(sig_idx))] + else: + idx_map = [ + np.searchsorted(sig_idx, ind) for ind in indices_use] + # allocate space to accumulate PSD + if accumulate_psd: + if n_times_spectrum == 0: + psd_shape = (n_signals_use, n_freqs) + else: + psd_shape = (n_signals_use, n_freqs, n_times_spectrum) + psd = np.zeros(psd_shape) + else: + psd = None + + # create instances of the connectivity estimators + con_methods = [] + for mtype_i, mtype in enumerate(con_method_types): + method_params = dict(n_cons=n_cons, n_freqs=n_freqs, + n_times=n_times_spectrum) + if method[mtype_i] in _multivariate_methods: + method_params.update(dict(n_signals=n_signals_use)) + if method[mtype_i] in _gc_methods: + method_params.update(dict(n_lags=gc_n_lags)) + con_methods.append(mtype(**method_params)) + + sep = ', ' + metrics_str = sep.join([meth.name for meth in con_methods]) + logger.info(' the following metrics will be computed: %s' + % metrics_str) + + # check dimensions and time scale + for this_epoch in epoch_block: + _, _, _, warn_times = _get_and_verify_data_sizes( + this_epoch, sfreq, n_signals, n_times_in, times_in, + warn_times=warn_times) + + call_params = dict( + sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, + method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, + n_cons=n_cons, block_size=block_size, + psd=psd, accumulate_psd=accumulate_psd, + mt_adaptive=mt_adaptive, + con_method_types=con_method_types, + con_methods=con_methods if n_jobs == 1 else None, + n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, + gc_n_lags=gc_n_lags, + accumulate_inplace=True if n_jobs == 1 else False) + call_params.update(**spectral_params) + + if n_jobs == 1: + # no parallel processing + for this_epoch in epoch_block: + logger.info(' computing cross-spectral density for epoch %d' + % (epoch_idx + 1)) + # con methods and psd are updated inplace + _epoch_spectral_connectivity(data=this_epoch, **call_params) + epoch_idx += 1 + else: + # process epochs in parallel + logger.info( + ' computing cross-spectral density for epochs %d..%d' + % (epoch_idx + 1, epoch_idx + len(epoch_block))) + + out = parallel(my_epoch_spectral_connectivity( + data=this_epoch, **call_params) + for this_epoch in epoch_block) + # do the accumulation + for this_out in out: + for _method, parallel_method in zip(con_methods, this_out[0]): + _method.combine(parallel_method) + if accumulate_psd: + psd += this_out[1] + + epoch_idx += len(epoch_block) + + # normalize + n_epochs = epoch_idx + if accumulate_psd: + psd /= n_epochs + + # compute final connectivity scores + con = list() + patterns = list() + for method_i, conn_method in enumerate(con_methods): + + # future estimators will need to be handled here + if conn_method.accumulate_psd: + # compute scores block-wise to save memory + for i in range(0, n_cons, block_size): + con_idx = slice(i, i + block_size) + psd_xx = psd[idx_map[0][con_idx]] + psd_yy = psd[idx_map[1][con_idx]] + conn_method.compute_con(con_idx, n_epochs, psd_xx, psd_yy) + else: + # compute all scores at once + if method[method_i] in _multivariate_methods: + conn_method.compute_con(indices_use, rank, n_epochs) + else: + conn_method.compute_con(slice(0, n_cons), n_epochs) + + # get the connectivity scores + this_con = conn_method.con_scores + this_patterns = conn_method.patterns + + if this_con.shape[0] != n_cons: + raise RuntimeError( + 'first dimension of connectivity scores does not match the ' + 'number of connections; please contact the mne-connectivity ' + 'developers') + if faverage: + if this_con.shape[1] != n_freqs: + raise RuntimeError( + 'second dimension of connectivity scores does not match ' + 'the number of frequencies; please contact the ' + 'mne-connectivity developers') + con_shape = (n_cons, n_bands) + this_con.shape[2:] + this_con_bands = np.empty(con_shape, dtype=this_con.dtype) + for band_idx in range(n_bands): + this_con_bands[:, band_idx] = np.mean( + this_con[:, freq_idx_bands[band_idx]], axis=1) + this_con = this_con_bands + + if this_patterns is not None: + patterns_shape = list(this_patterns.shape) + patterns_shape[3] = n_bands + this_patterns_bands = np.empty(patterns_shape, + dtype=this_patterns.dtype) + for band_idx in range(n_bands): + this_patterns_bands[:, :, :, band_idx] = np.mean( + this_patterns[:, :, :, freq_idx_bands[band_idx]], + axis=3) + this_patterns = this_patterns_bands + + con.append(this_con) + patterns.append(this_patterns) -def _store_results( - con, patterns, method, freqs, faverage, freqs_bands, names, mode, indices, - n_epochs, times, n_tapers, metadata, events, event_id, rank, gc_n_lags, - n_signals -): - """Store results in connectivity containers.""" freqs_used = freqs if faverage: # for each band we return the frequencies that were averaged @@ -572,6 +2090,23 @@ def _store_results( freqs_used = freqs_bands freqs_used = [[np.min(band), np.max(band)] for band in freqs_used] + if indices is None and not multivariate_con: + # return all-to-all connectivity matrices + # raveled into a 1D array + logger.info(' assembling connectivity matrix') + con_flat = con + con = list() + for this_con_flat in con_flat: + this_con = np.zeros((n_signals, n_signals) + + this_con_flat.shape[1:], + dtype=this_con_flat.dtype) + this_con[indices_use] = this_con_flat + + # ravel 2D connectivity into a 1D array + # while keeping other dimensions + this_con = this_con.reshape((n_signals ** 2,) + + this_con_flat.shape[1:]) + con.append(this_con) # number of nodes in the original data n_nodes = n_signals @@ -596,7 +2131,7 @@ def _store_results( logger.info('[Connectivity computation done]') - if len(method) == 1: + if n_methods == 1: # for a single method return connectivity directly conn_list = conn_list[0] diff --git a/mne_connectivity/spectral/epochs_bivariate.py b/mne_connectivity/spectral/epochs_bivariate.py deleted file mode 100644 index 044de3b4..00000000 --- a/mne_connectivity/spectral/epochs_bivariate.py +++ /dev/null @@ -1,729 +0,0 @@ -# Authors: Martin Luessi -# Denis A. Engemann -# Adam Li -# Thomas S. Binns -# -# License: BSD (3-clause) - -import numpy as np -from mne.utils import logger, verbose - -from .epochs import ( - _AbstractConEstBase, _check_spectral_connectivity_epochs_settings, - _check_spectral_connectivity_epochs_data, _get_n_epochs, - _prepare_connectivity, _assemble_spectral_params, - _compute_spectral_methods_epochs, _store_results) -from ..utils import fill_doc, check_indices - - -def _check_indices(indices, n_signals): - if indices is None: - logger.info('only using indices for lower-triangular matrix') - # only compute r for lower-triangular region - indices_use = np.tril_indices(n_signals, -1) - else: - indices_use = check_indices(indices) - - # number of connectivities to compute - n_cons = len(indices_use[0]) - logger.info(' computing connectivity for %d connections' % n_cons) - - return n_cons, indices_use - - -######################################################################## -# Bivariate connectivity estimators - - -class _EpochMeanConEstBase(_AbstractConEstBase): - """Base class for methods that estimate connectivity as mean epoch-wise.""" - - patterns = None - - def __init__(self, n_cons, n_freqs, n_times): - self.n_cons = n_cons - self.n_freqs = n_freqs - self.n_times = n_times - - if n_times == 0: - self.csd_shape = (n_cons, n_freqs) - else: - self.csd_shape = (n_cons, n_freqs, n_times) - - self.con_scores = None - - def start_epoch(self): # noqa: D401 - """Called at the start of each epoch.""" - pass # for this type of con. method we don't do anything - - def combine(self, other): - """Include con. accumated for some epochs in this estimate.""" - self._acc += other._acc - - -class _CohEstBase(_EpochMeanConEstBase): - """Base Estimator for Coherence, Coherency, Imag. Coherence.""" - - accumulate_psd = True - - def __init__(self, n_cons, n_freqs, n_times): - super(_CohEstBase, self).__init__(n_cons, n_freqs, n_times) - - # allocate space for accumulation of CSD - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate CSD for some connections.""" - self._acc[con_idx] += csd_xy - - -class _CohEst(_CohEstBase): - """Coherence Estimator.""" - - name = 'Coherence' - - def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - csd_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy) - - -class _CohyEst(_CohEstBase): - """Coherency Estimator.""" - - name = 'Coherency' - - def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape, - dtype=np.complex128) - csd_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = csd_mean / np.sqrt(psd_xx * psd_yy) - - -class _ImCohEst(_CohEstBase): - """Imaginary Coherence Estimator.""" - - name = 'Imaginary Coherence' - - def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy): # lgtm - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - csd_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = np.imag(csd_mean) / np.sqrt(psd_xx * psd_yy) - - -class _PLVEst(_EpochMeanConEstBase): - """PLV Estimator.""" - - name = 'PLV' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_PLVEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += csd_xy / np.abs(csd_xy) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - plv = np.abs(self._acc / n_epochs) - self.con_scores[con_idx] = plv - - -class _ciPLVEst(_EpochMeanConEstBase): - """corrected imaginary PLV Estimator.""" - - name = 'ciPLV' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_ciPLVEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += csd_xy / np.abs(csd_xy) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - imag_plv = np.abs(np.imag(self._acc)) / n_epochs - real_plv = np.real(self._acc) / n_epochs - real_plv = np.clip(real_plv, -1, 1) # bounded from -1 to 1 - mask = (np.abs(real_plv) == 1) # avoid division by 0 - real_plv[mask] = 0 - corrected_imag_plv = imag_plv / np.sqrt(1 - real_plv ** 2) - self.con_scores[con_idx] = corrected_imag_plv - - -class _PLIEst(_EpochMeanConEstBase): - """PLI Estimator.""" - - name = 'PLI' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_PLIEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += np.sign(np.imag(csd_xy)) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - pli_mean = self._acc[con_idx] / n_epochs - self.con_scores[con_idx] = np.abs(pli_mean) - - -class _PLIUnbiasedEst(_PLIEst): - """Unbiased PLI Square Estimator.""" - - name = 'Unbiased PLI Square' - accumulate_psd = False - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - pli_mean = self._acc[con_idx] / n_epochs - - # See Vinck paper Eq. (30) - con = (n_epochs * pli_mean ** 2 - 1) / (n_epochs - 1) - - self.con_scores[con_idx] = con - - -class _DPLIEst(_EpochMeanConEstBase): - """DPLI Estimator.""" - - name = 'DPLI' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_DPLIEst, self).__init__(n_cons, n_freqs, n_times) - - # allocate accumulator - self._acc = np.zeros(self.csd_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - self._acc[con_idx] += np.heaviside(np.imag(csd_xy), 0.5) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - con = self._acc[con_idx] / n_epochs - - self.con_scores[con_idx] = con - - -class _WPLIEst(_EpochMeanConEstBase): - """WPLI Estimator.""" - - name = 'WPLI' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_WPLIEst, self).__init__(n_cons, n_freqs, n_times) - - # store both imag(csd) and abs(imag(csd)) - acc_shape = (2,) + self.csd_shape - self._acc = np.zeros(acc_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - im_csd = np.imag(csd_xy) - self._acc[0, con_idx] += im_csd - self._acc[1, con_idx] += np.abs(im_csd) - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - num = np.abs(self._acc[0, con_idx]) - denom = self._acc[1, con_idx] - - # handle zeros in denominator - z_denom = np.where(denom == 0.) - denom[z_denom] = 1. - - con = num / denom - - # where we had zeros in denominator, we set con to zero - con[z_denom] = 0. - - self.con_scores[con_idx] = con - - -class _WPLIDebiasedEst(_EpochMeanConEstBase): - """Debiased WPLI Square Estimator.""" - - name = 'Debiased WPLI Square' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_WPLIDebiasedEst, self).__init__(n_cons, n_freqs, n_times) - # store imag(csd), abs(imag(csd)), imag(csd)^2 - acc_shape = (3,) + self.csd_shape - self._acc = np.zeros(acc_shape) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - im_csd = np.imag(csd_xy) - self._acc[0, con_idx] += im_csd - self._acc[1, con_idx] += np.abs(im_csd) - self._acc[2, con_idx] += im_csd ** 2 - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - # note: we use the trick from fieldtrip to compute the - # the estimate over all pairwise epoch combinations - sum_im_csd = self._acc[0, con_idx] - sum_abs_im_csd = self._acc[1, con_idx] - sum_sq_im_csd = self._acc[2, con_idx] - - denom = sum_abs_im_csd ** 2 - sum_sq_im_csd - - # handle zeros in denominator - z_denom = np.where(denom == 0.) - denom[z_denom] = 1. - - con = (sum_im_csd ** 2 - sum_sq_im_csd) / denom - - # where we had zeros in denominator, we set con to zero - con[z_denom] = 0. - - self.con_scores[con_idx] = con - - -class _PPCEst(_EpochMeanConEstBase): - """Pairwise Phase Consistency (PPC) Estimator.""" - - name = 'PPC' - accumulate_psd = False - - def __init__(self, n_cons, n_freqs, n_times): - super(_PPCEst, self).__init__(n_cons, n_freqs, n_times) - - # store csd / abs(csd) - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - def accumulate(self, con_idx, csd_xy): - """Accumulate some connections.""" - denom = np.abs(csd_xy) - z_denom = np.where(denom == 0.) - denom[z_denom] = 1. - this_acc = csd_xy / denom - this_acc[z_denom] = 0. # handle division by zero - - self._acc[con_idx] += this_acc - - def compute_con(self, con_idx, n_epochs): - """Compute final con. score for some connections.""" - if self.con_scores is None: - self.con_scores = np.zeros(self.csd_shape) - - # note: we use the trick from fieldtrip to compute the - # the estimate over all pairwise epoch combinations - con = ((self._acc[con_idx] * np.conj(self._acc[con_idx]) - n_epochs) / - (n_epochs * (n_epochs - 1.))) - - self.con_scores[con_idx] = np.real(con) - - -############################################################################### - - -# map names to estimator types -_CON_METHOD_MAP = {'coh': _CohEst, 'cohy': _CohyEst, 'imcoh': _ImCohEst, - 'plv': _PLVEst, 'ciplv': _ciPLVEst, 'ppc': _PPCEst, - 'pli': _PLIEst, 'pli2_unbiased': _PLIUnbiasedEst, - 'dpli': _DPLIEst, 'wpli': _WPLIEst, - 'wpli2_debiased': _WPLIDebiasedEst} - - -@ verbose -@ fill_doc -def spectral_connectivity_epochs( - data, names=None, method='coh', indices=None, sfreq=None, - mode='multitaper', fmin=None, fmax=np.inf, fskip=0, faverage=False, - tmin=None, tmax=None, mt_bandwidth=None, mt_adaptive=False, - mt_low_bias=True, cwt_freqs=None, cwt_n_cycles=7, block_size=1000, - n_jobs=1, verbose=None -): - """Compute bivariate (time-)frequency-domain connectivity measures. - - The connectivity method(s) are specified using the "method" parameter. - All methods are based on estimates of the cross- and power spectral - densities (CSD/PSD) Sxy and Sxx, Syy. - - Parameters - ---------- - data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs - The data from which to compute connectivity. Note that it is also - possible to combine multiple signals by providing a list of tuples, - e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], - corresponds to 3 epochs, and arr_* could be an array with the same - number of time points as stc_*. The array-like object can also - be a list/generator of array, shape =(n_signals, n_times), - or a list/generator of SourceEstimate or VolSourceEstimate objects. - %(names)s - method : str | list of str - Connectivity measure(s) to compute. These can be ``['coh', 'cohy', - 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', - 'wpli2_debiased']``. - indices : tuple of array | None - Two arrays with indices of connections for which to compute - connectivity. Each array for the seeds and targets should contain the - channel indices for each bivariate connection. If ``None``, connections - between all channels are computed. - sfreq : float - The sampling frequency. Required if data is not - :class:`Epochs `. - mode : str - Spectrum estimation mode can be either: 'multitaper', 'fourier', or - 'cwt_morlet'. - fmin : float | tuple of float - The lower frequency of interest. Multiple bands are defined using - a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. - fmax : float | tuple of float - The upper frequency of interest. Multiple bands are dedined using - a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. - fskip : int - Omit every "(fskip + 1)-th" frequency bin to decimate in frequency - domain. - faverage : bool - Average connectivity scores for each frequency band. If True, - the output freqs will be a list with arrays of the frequencies - that were averaged. - tmin : float | None - Time to start connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - tmax : float | None - Time to end connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - mt_bandwidth : float | None - The bandwidth of the multitaper windowing function in Hz. - Only used in 'multitaper' mode. - mt_adaptive : bool - Use adaptive weights to combine the tapered spectra into PSD. - Only used in 'multitaper' mode. - mt_low_bias : bool - Only use tapers with more than 90 percent spectral concentration - within bandwidth. Only used in 'multitaper' mode. - cwt_freqs : array - Array of frequencies of interest. Only used in 'cwt_morlet' mode. - cwt_n_cycles : float | array of float - Number of cycles. Fixed number or one per frequency. Only used in - 'cwt_morlet' mode. - block_size : int - How many connections to compute at once (higher numbers are faster - but require more memory). - n_jobs : int - How many samples to process in parallel. - %(verbose)s - - Returns - ------- - con : array | list of array - Computed connectivity measure(s). Either an instance of - ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of the connectivity result will be: - - - ``(n_cons, n_freqs)`` for multitaper or fourier modes - - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode - - ``n_cons = n_signals ** 2`` with ``indices=None`` - - ``n_cons = len(indices[0])`` when indices is supplied. - - See Also - -------- - mne_connectivity.spectral_connectivity_epochs_multivariate - mne_connectivity.spectral_connectivity_time - mne_connectivity.SpectralConnectivity - mne_connectivity.SpectroTemporalConnectivity - - Notes - ----- - Please note that the interpretation of the measures in this function - depends on the data and underlying assumptions and does not necessarily - reflect a causal relationship between brain regions. - - These measures are not to be interpreted over time. Each Epoch passed into - the dataset is interpreted as an independent sample of the same - connectivity structure. Within each Epoch, it is assumed that the spectral - measure is stationary. The spectral measures implemented in this function - are computed across Epochs. **Thus, spectral measures computed with only - one Epoch will result in errorful values and spectral measures computed - with few Epochs will be unreliable.** Please see - ``spectral_connectivity_time`` for time-resolved connectivity estimation. - - The spectral densities can be estimated using a multitaper method with - digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier - transform with Hanning windows, or a continuous wavelet transform using - Morlet wavelets. The spectral estimation mode is specified using the - "mode" parameter. - - By default, the connectivity between all signals is computed (only - connections corresponding to the lower-triangular part of the connectivity - matrix). If one is only interested in the connectivity between some - signals, the "indices" parameter can be used. For example, to compute the - connectivity between the signal with index 0 and signals "2, 3, 4" (a total - of 3 connections) one can use the following:: - - indices = (np.array([0, 0, 0]), # row indices - np.array([2, 3, 4])) # col indices - - con = spectral_connectivity_epochs(data, method='coh', - indices=indices, ...) - - In this case con.get_data().shape = (3, n_freqs). The connectivity scores - are in the same order as defined indices. - - **Supported Connectivity Measures** - - The connectivity method(s) is specified using the "method" parameter. The - following methods are supported (note: ``E[]`` denotes average over - epochs). Multiple measures can be computed at once by using a list/tuple, - e.g., ``['coh', 'pli']`` to compute coherence and PLI. - - 'coh' : Coherence given by:: - - | E[Sxy] | - C = --------------------- - sqrt(E[Sxx] * E[Syy]) - - 'cohy' : Coherency given by:: - - E[Sxy] - C = --------------------- - sqrt(E[Sxx] * E[Syy]) - - 'imcoh' : Imaginary coherence :footcite:`NolteEtAl2004` given by:: - - Im(E[Sxy]) - C = ---------------------- - sqrt(E[Sxx] * E[Syy]) - - 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given - by:: - - PLV = |E[Sxy/|Sxy|]| - - 'ciplv' : corrected imaginary PLV (ciPLV) - :footcite:`BrunaEtAl2018` given by:: - - |E[Im(Sxy/|Sxy|)]| - ciPLV = ------------------------------------ - sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) - - 'ppc' : Pairwise Phase Consistency (PPC), an unbiased estimator - of squared PLV :footcite:`VinckEtAl2010`. - - 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: - - PLI = |E[sign(Im(Sxy))]| - - 'pli2_unbiased' : Unbiased estimator of squared PLI - :footcite:`VinckEtAl2011`. - - 'dpli' : Directed Phase Lag Index (DPLI) :footcite:`StamEtAl2012` - given by (where H is the Heaviside function):: - - DPLI = E[H(Im(Sxy))] - - 'wpli' : Weighted Phase Lag Index (WPLI) :footcite:`VinckEtAl2011` - given by:: - - |E[Im(Sxy)]| - WPLI = ------------------ - E[|Im(Sxy)|] - - 'wpli2_debiased' : Debiased estimator of squared WPLI - :footcite:`VinckEtAl2011`. - - References - ---------- - .. footbibliography:: - """ - ( - fmin, fmax, n_bands, method, con_method_types, accumulate_psd, - parallel, my_epoch_spectral_connectivity - ) = _check_spectral_connectivity_epochs_settings( - method, fmin, fmax, n_jobs, verbose, _CON_METHOD_MAP) - - (names, times_in, sfreq, events, event_id, - metadata) = _check_spectral_connectivity_epochs_data(data, sfreq, names) - - # loop over data; it could be a generator that returns - # (n_signals x n_times) arrays or SourceEstimates - epoch_idx = 0 - logger.info('Connectivity computation...') - warn_times = True - for epoch_block in _get_n_epochs(data, n_jobs): - if epoch_idx == 0: - # initialize everything times and frequencies - (times, n_times, times_in, n_times_in, tmin_idx, tmax_idx, n_freqs, - freq_mask, freqs, freqs_bands, freq_idx_bands, n_signals, - warn_times) = _prepare_connectivity( - epoch_block=epoch_block, times_in=times_in, tmin=tmin, - tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, mode=mode, - fskip=fskip, n_bands=n_bands, cwt_freqs=cwt_freqs, - faverage=faverage) - - # check indices input - n_cons, indices_use = _check_indices(indices, n_signals) - - # get the window function, wavelets, etc for different modes - (spectral_params, mt_adaptive, n_times_spectrum, - n_tapers) = _assemble_spectral_params( - mode=mode, n_times=n_times, mt_adaptive=mt_adaptive, - mt_bandwidth=mt_bandwidth, sfreq=sfreq, - mt_low_bias=mt_low_bias, cwt_n_cycles=cwt_n_cycles, - cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) - - # unique signals for which we actually need to compute CSD/PSD - sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) - n_signals_use = len(sig_idx) - - # map indices to unique indices - idx_map = [np.searchsorted(sig_idx, ind) for ind in indices_use] - - # allocate space to accumulate PSD - if accumulate_psd: - if n_times_spectrum == 0: - psd_shape = (n_signals_use, n_freqs) - else: - psd_shape = (n_signals_use, n_freqs, n_times_spectrum) - psd = np.zeros(psd_shape) - else: - psd = None - - # create instances of the connectivity estimators - con_methods = [] - for mtype in con_method_types: - con_methods.append(mtype(n_cons=n_cons, n_freqs=n_freqs, - n_times=n_times_spectrum)) - - sep = ', ' - metrics_str = sep.join([meth.name for meth in con_methods]) - logger.info(' the following metrics will be computed: %s' - % metrics_str) - - call_params = dict( - sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, - method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, - n_cons=n_cons, block_size=block_size, - psd=psd, accumulate_psd=accumulate_psd, - mt_adaptive=mt_adaptive, - con_method_types=con_method_types, - con_methods=con_methods if n_jobs == 1 else None, - n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, - gc_n_lags=None, multivariate_con=False, - accumulate_inplace=True if n_jobs == 1 else False) - call_params.update(**spectral_params) - - epoch_idx = _compute_spectral_methods_epochs( - con_methods, epoch_block, epoch_idx, call_params, parallel, - my_epoch_spectral_connectivity, n_jobs, n_times_in, times_in, - warn_times) - - # normalize - n_epochs = epoch_idx - if accumulate_psd: - psd /= n_epochs - - # compute final connectivity scores - con = list() - for conn_method in con_methods: - - # future estimators will need to be handled here - if conn_method.accumulate_psd: - # compute scores block-wise to save memory - for i in range(0, n_cons, block_size): - con_idx = slice(i, i + block_size) - psd_xx = psd[idx_map[0][con_idx]] - psd_yy = psd[idx_map[1][con_idx]] - conn_method.compute_con(con_idx, n_epochs, psd_xx, psd_yy) - else: - # compute all scores at once - conn_method.compute_con(slice(0, n_cons), n_epochs) - - # get the connectivity scores - this_con = conn_method.con_scores - - if this_con.shape[0] != n_cons: - raise RuntimeError( - 'first dimension of connectivity scores does not match the ' - 'number of connections; please contact the mne-connectivity ' - 'developers') - if faverage: - if this_con.shape[1] != n_freqs: - raise RuntimeError( - 'second dimension of connectivity scores does not match ' - 'the number of frequencies; please contact the ' - 'mne-connectivity developers') - con_shape = (n_cons, n_bands) + this_con.shape[2:] - this_con_bands = np.empty(con_shape, dtype=this_con.dtype) - for band_idx in range(n_bands): - this_con_bands[:, band_idx] = np.mean( - this_con[:, freq_idx_bands[band_idx]], axis=1) - this_con = this_con_bands - - con.append(this_con) - # No patterns for bivariate connectivity - patterns = [None for _ in range(len(con))] - - # return all-to-all connectivity matrices raveled into a 1D array - if indices is None: - logger.info(' assembling connectivity matrix') - con_flat = con - con = list() - for this_con_flat in con_flat: - this_con = np.zeros((n_signals, n_signals) + - this_con_flat.shape[1:], - dtype=this_con_flat.dtype) - this_con[indices_use] = this_con_flat - - # ravel 2D connectivity into a 1D array - # while keeping other dimensions - this_con = this_con.reshape((n_signals ** 2,) + - this_con_flat.shape[1:]) - con.append(this_con) - - conn_list = _store_results( - con=con, patterns=patterns, method=method, freqs=freqs, - faverage=faverage, freqs_bands=freqs_bands, names=names, mode=mode, - indices=indices, n_epochs=n_epochs, times=times, n_tapers=n_tapers, - metadata=metadata, events=events, event_id=event_id, rank=None, - gc_n_lags=None, n_signals=n_signals) - - return conn_list diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py deleted file mode 100644 index 28077adb..00000000 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ /dev/null @@ -1,1129 +0,0 @@ -# Authors: Martin Luessi -# Denis A. Engemann -# Adam Li -# Thomas S. Binns -# Tien D. Nguyen -# Richard M. Köhler -# -# License: BSD (3-clause) - -import numpy as np -import scipy as sp -from mne.epochs import BaseEpochs -from mne.parallel import parallel_func -from mne.utils import ProgressBar, logger, verbose - -from .epochs import ( - _AbstractConEstBase, _check_spectral_connectivity_epochs_settings, - _check_spectral_connectivity_epochs_data, _get_n_epochs, - _prepare_connectivity, _assemble_spectral_params, - _compute_spectral_methods_epochs, _store_results) -from ..utils import fill_doc, check_multivariate_indices - - -def _check_indices(indices, method, n_signals): - if indices is None: - if any(this_method in _gc_methods for this_method in method): - raise ValueError( - 'indices must be specified when computing Granger causality, ' - 'as all-to-all connectivity is not supported') - else: - logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], - np.arange(n_signals, dtype=int)[np.newaxis, :]) - else: - indices_use = check_multivariate_indices(indices) # pad with -1 - if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices[0], indices[1]): - intersection = np.intersect1d(seed, target) - if np.any(intersection != -1): # ignore padded entries - raise ValueError( - 'seed and target indices must not intersect when ' - 'computing Granger causality') - - # number of connectivities to compute - n_cons = len(indices_use[0]) - logger.info(' computing connectivity for %d connections' % n_cons) - - return n_cons, indices_use - - -def _check_rank_input(rank, data, indices): - """Check the rank argument is appropriate and compute rank if missing.""" - sv_tol = 1e-10 # tolerance for non-zero singular val (rel. to largest) - if rank is None: - rank = np.zeros((2, len(indices[0])), dtype=int) - - if isinstance(data, BaseEpochs): - data_arr = data.get_data() - else: - data_arr = data - - # XXX: Unpadding of arrays after already padding them is perhaps not so - # efficient. However, we need to remove the padded values to - # ensure only the correct channels are indexed, and having two - # versions of indices is a bit messy currently. A candidate for - # refactoring to simplify code. - - for group_i in range(2): # seeds and targets - for con_i, con_idcs in enumerate(indices[group_i]): - con_idcs = con_idcs[con_idcs != -1] # -1 is padded value - s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) - rank[group_i][con_i] = np.min( - [np.count_nonzero(epoch >= epoch[0] * sv_tol) - for epoch in s]) - - logger.info('Estimated data ranks:') - con_i = 1 - for seed_rank, target_rank in zip(rank[0], rank[1]): - logger.info(' connection %i - seeds (%i); targets (%i)' - % (con_i, seed_rank, target_rank, )) - con_i += 1 - rank = tuple((np.array(rank[0]), np.array(rank[1]))) - - else: - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], rank[0], rank[1]): - if not (0 < seed_rank <= len(seed_idcs) and - 0 < target_rank <= len(target_idcs)): - raise ValueError( - 'ranks for seeds and targets must be > 0 and <= the ' - 'number of channels in the seeds and targets, ' - 'respectively, for each connection') - - return rank - - -######################################################################## -# Multivariate connectivity estimators - -class _EpochMeanMultivariateConEstBase(_AbstractConEstBase): - """Base class for mean epoch-wise multivar. con. estimation methods.""" - - n_steps = None - patterns = None - - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): - self.n_signals = n_signals - self.n_cons = n_cons - self.n_freqs = n_freqs - self.n_times = n_times - self.n_jobs = n_jobs - - # include time dimension, even when unused for indexing flexibility - if n_times == 0: - self.csd_shape = (n_signals**2, n_freqs) - self.con_scores = np.zeros((n_cons, n_freqs, 1)) - else: - self.csd_shape = (n_signals**2, n_freqs, n_times) - self.con_scores = np.zeros((n_cons, n_freqs, n_times)) - - # allocate space for accumulation of CSD - self._acc = np.zeros(self.csd_shape, dtype=np.complex128) - - self._compute_n_progress_bar_steps() - - def start_epoch(self): # noqa: D401 - """Called at the start of each epoch.""" - pass # for this type of con. method we don't do anything - - def combine(self, other): - """Include con. accumulated for some epochs in this estimate.""" - self._acc += other._acc - - def accumulate(self, con_idx, csd_xy): - """Accumulate CSD for some connections.""" - self._acc[con_idx] += csd_xy - - def _compute_n_progress_bar_steps(self): - """Calculate the number of steps to include in the progress bar.""" - self.n_steps = int(np.ceil(self.n_freqs / self.n_jobs)) - - def _log_connection_number(self, con_i): - """Log the number of the connection being computed.""" - logger.info('Computing %s for connection %i of %i' - % (self.name, con_i + 1, self.n_cons, )) - - def _get_block_indices(self, block_i, limit): - """Get indices for a computation block capped by a limit.""" - indices = np.arange(block_i * self.n_jobs, (block_i + 1) * self.n_jobs) - - return indices[np.nonzero(indices < limit)] - - def reshape_csd(self): - """Reshape CSD into a matrix of times x freqs x signals x signals.""" - if self.n_times == 0: - return (np.reshape(self._acc, ( - self.n_signals, self.n_signals, self.n_freqs, 1) - ).transpose(3, 2, 0, 1)) - - return (np.reshape(self._acc, ( - self.n_signals, self.n_signals, self.n_freqs, self.n_times) - ).transpose(3, 2, 0, 1)) - - -class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): - """Base estimator for multivariate imag. part of coherency methods. - - See Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 - for equation references. - """ - - name = None - accumulate_psd = False - - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): - super(_MultivariateCohEstBase, self).__init__( - n_signals, n_cons, n_freqs, n_times, n_jobs) - - def compute_con(self, indices, ranks, n_epochs=1): - """Compute multivariate imag. part of coherency between signals.""" - assert self.name in ['MIC', 'MIM'], ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') - - csd = self.reshape_csd() / n_epochs - n_times = csd.shape[0] - times = np.arange(n_times) - freqs = np.arange(self.n_freqs) - - if self.name == 'MIC': - self.patterns = np.full( - (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), - np.nan) - - con_i = 0 - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], ranks[0], ranks[1]): - self._log_connection_number(con_i) - - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] - con_idcs = [*seed_idcs, *target_idcs] - - C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - - # Eqs. 32 & 33 - C_bar, U_bar_aa, U_bar_bb = self._csd_svd( - C, seed_idcs, seed_rank, target_rank) - - # Eqs. 3 & 4 - E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) - - if self.name == 'MIC': - self._compute_mic(E, C, seed_idcs, target_idcs, n_times, - U_bar_aa, U_bar_bb, con_i) - else: - self._compute_mim(E, seed_idcs, target_idcs, con_i) - - con_i += 1 - - self.reshape_results() - - def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): - """Dimensionality reduction of CSD with SVD.""" - n_times = csd.shape[0] - n_seeds = len(seed_idcs) - n_targets = csd.shape[3] - n_seeds - - C_aa = csd[..., :n_seeds, :n_seeds] - C_ab = csd[..., :n_seeds, n_seeds:] - C_bb = csd[..., n_seeds:, n_seeds:] - C_ba = csd[..., n_seeds:, :n_seeds] - - # Eq. 32 - if seed_rank != n_seeds: - U_aa = np.linalg.svd(np.real(C_aa), full_matrices=False)[0] - U_bar_aa = U_aa[..., :seed_rank] - else: - U_bar_aa = np.broadcast_to( - np.identity(n_seeds), - (n_times, self.n_freqs) + (n_seeds, n_seeds)) - - if target_rank != n_targets: - U_bb = np.linalg.svd(np.real(C_bb), full_matrices=False)[0] - U_bar_bb = U_bb[..., :target_rank] - else: - U_bar_bb = np.broadcast_to( - np.identity(n_targets), - (n_times, self.n_freqs) + (n_targets, n_targets)) - - # Eq. 33 - C_bar_aa = np.matmul( - U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul( - U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul( - U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul( - U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa)) - C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), - np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) - - return C_bar, U_bar_aa, U_bar_bb - - def _compute_e(self, csd, n_seeds): - """Compute E from the CSD.""" - C_r = np.real(csd) - - parallel, parallel_compute_t, _ = parallel_func( - _mic_mim_compute_t, self.n_jobs, verbose=False) - - # imag. part of T filled when data is rank-deficient - T = np.zeros(csd.shape, dtype=np.complex128) - for block_i in ProgressBar( - range(self.n_steps), mesg="frequency blocks"): - freqs = self._get_block_indices(block_i, self.n_freqs) - T[:, freqs] = np.array(parallel(parallel_compute_t( - C_r[:, f], T[:, f], n_seeds) for f in freqs) - ).transpose(1, 0, 2, 3) - - if not np.isreal(T).all() or not np.isfinite(T).all(): - raise RuntimeError( - 'the transformation matrix of the data must be real-valued ' - 'and contain no NaN or infinity values; check that you are ' - 'using full rank data or specify an appropriate rank for the ' - 'seeds and targets that is less than or equal to their ranks') - T = np.real(T) # make T real if check passes - - # Eq. 4 - D = np.matmul(T, np.matmul(csd, T)) - - # E as imag. part of D between seeds and targets - return np.imag(D[..., :n_seeds, n_seeds:]) - - def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, - U_bar_bb, con_i): - """Compute MIC and the associated spatial patterns.""" - n_seeds = len(seed_idcs) - n_targets = len(target_idcs) - times = np.arange(n_times) - freqs = np.arange(self.n_freqs) - - # Eigendecomp. to find spatial filters for seeds and targets - w_seeds, V_seeds = np.linalg.eigh( - np.matmul(E, E.transpose(0, 1, 3, 2))) - w_targets, V_targets = np.linalg.eigh( - np.matmul(E.transpose(0, 1, 3, 2), E)) - if ( - len(seed_idcs) == len(target_idcs) and - np.all(np.sort(seed_idcs) == np.sort(target_idcs)) - ): - # strange edge-case where the eigenvectors returned should be a set - # of identity matrices with one rotated by 90 degrees, but are - # instead identical (i.e. are not rotated versions of one another). - # This leads to the case where the spatial filters are incorrectly - # applied, resulting in connectivity estimates of ~0 when they - # should be perfectly correlated ~1. Accordingly, we manually - # create a set of rotated identity matrices to use as the filters. - create_filter = False - stop = False - while not create_filter and not stop: - for time_i in range(n_times): - for freq_i in range(self.n_freqs): - if np.all(V_seeds[time_i, freq_i] == - V_targets[time_i, freq_i]): - create_filter = True - break - stop = True - if create_filter: - n_chans = E.shape[2] - eye_4d = np.zeros_like(V_seeds) - eye_4d[:, :, np.arange(n_chans), np.arange(n_chans)] = 1 - V_seeds = eye_4d - V_targets = np.rot90(eye_4d, axes=(2, 3)) - - # Spatial filters with largest eigval. for seeds and targets - alpha = V_seeds[times[:, None], freqs, :, w_seeds.argmax(axis=2)] - beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] - - # Eq. 46 (seed spatial patterns) - self.patterns[0, con_i, :n_seeds] = (np.matmul( - np.real(C[..., :n_seeds, :n_seeds]), - np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T - - # Eq. 47 (target spatial patterns) - self.patterns[1, con_i, :n_targets] = (np.matmul( - np.real(C[..., n_seeds:, n_seeds:]), - np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T - - # Eq. 7 - self.con_scores[con_i] = (np.einsum( - 'ijk,ijk->ij', alpha, np.matmul(E, np.expand_dims( - beta, axis=3))[..., 0] - ) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2)).T - - def _compute_mim(self, E, seed_idcs, target_idcs, con_i): - """Compute MIM (a.k.a. GIM if seeds == targets).""" - # Eq. 14 - self.con_scores[con_i] = np.matmul( - E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T - - # Eq. 15 - if ( - len(seed_idcs) == len(target_idcs) and - np.all(np.sort(seed_idcs) == np.sort(target_idcs)) - ): - self.con_scores[con_i] *= 0.5 - - def reshape_results(self): - """Remove time dimension from results, if necessary.""" - if self.n_times == 0: - self.con_scores = self.con_scores[..., 0] - if self.patterns is not None: - self.patterns = self.patterns[..., 0] - - -def _mic_mim_compute_t(C, T, n_seeds): - """Compute T for a single frequency (used for MIC and MIM).""" - for time_i in range(C.shape[0]): - T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( - C[time_i, :n_seeds, :n_seeds], -0.5 - ) - T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power( - C[time_i, n_seeds:, n_seeds:], -0.5 - ) - - return T - - -class _MICEst(_MultivariateCohEstBase): - """Multivariate imaginary part of coherency (MIC) estimator.""" - - name = "MIC" - - -class _MIMEst(_MultivariateCohEstBase): - """Multivariate interaction measure (MIM) estimator.""" - - name = "MIM" - - -class _GCEstBase(_EpochMeanMultivariateConEstBase): - """Base multivariate state-space Granger causality estimator.""" - - accumulate_psd = False - - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_lags, n_jobs=1): - super(_GCEstBase, self).__init__( - n_signals, n_cons, n_freqs, n_times, n_jobs) - - self.freq_res = (self.n_freqs - 1) * 2 - if n_lags >= self.freq_res: - raise ValueError( - 'the number of lags (%i) must be less than double the ' - 'frequency resolution (%i)' % (n_lags, self.freq_res, )) - self.n_lags = n_lags - - def compute_con(self, indices, ranks, n_epochs=1): - """Compute multivariate state-space Granger causality.""" - assert self.name in ['GC', 'GC time-reversed'], ( - 'the class name is not recognised, please contact the ' - 'mne-connectivity developers') - - csd = self.reshape_csd() / n_epochs - - n_times = csd.shape[0] - times = np.arange(n_times) - freqs = np.arange(self.n_freqs) - - con_i = 0 - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], ranks[0], ranks[1]): - self._log_connection_number(con_i) - - seed_idcs = seed_idcs[seed_idcs != -1] - target_idcs = target_idcs[target_idcs != -1] - con_idcs = [*seed_idcs, *target_idcs] - - C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - - C_bar = self._csd_svd(C, seed_idcs, seed_rank, target_rank) - n_signals = seed_rank + target_rank - con_seeds = np.arange(seed_rank) - con_targets = np.arange(target_rank) + seed_rank - - autocov = self._compute_autocov(C_bar) - if self.name == "GC time-reversed": - autocov = autocov.transpose(0, 1, 3, 2) - - A_f, V = self._autocov_to_full_var(autocov) - A_f_3d = np.reshape( - A_f, (n_times, n_signals, n_signals * self.n_lags), order="F") - A, K = self._full_var_to_iss(A_f_3d) - - self.con_scores[con_i] = self._iss_to_ugc( - A, A_f_3d, K, V, con_seeds, con_targets) - - con_i += 1 - - self.reshape_results() - - def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): - """Dimensionality reduction of CSD with SVD on the covariance.""" - # sum over times and epochs to get cov. from CSD - cov = csd.sum(axis=(0, 1)) - - n_seeds = len(seed_idcs) - n_targets = csd.shape[3] - n_seeds - - cov_aa = cov[:n_seeds, :n_seeds] - cov_bb = cov[n_seeds:, n_seeds:] - - if seed_rank != n_seeds: - U_aa = np.linalg.svd(np.real(cov_aa), full_matrices=False)[0] - U_bar_aa = U_aa[:, :seed_rank] - else: - U_bar_aa = np.identity(n_seeds) - - if target_rank != n_targets: - U_bb = np.linalg.svd(np.real(cov_bb), full_matrices=False)[0] - U_bar_bb = U_bb[:, :target_rank] - else: - U_bar_bb = np.identity(n_targets) - - C_aa = csd[..., :n_seeds, :n_seeds] - C_ab = csd[..., :n_seeds, n_seeds:] - C_bb = csd[..., n_seeds:, n_seeds:] - C_ba = csd[..., n_seeds:, :n_seeds] - - C_bar_aa = np.matmul( - U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul( - U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul( - U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul( - U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa)) - C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3), - np.append(C_bar_ba, C_bar_bb, axis=3), axis=2) - - return C_bar - - def _compute_autocov(self, csd): - """Compute autocovariance from the CSD.""" - n_times = csd.shape[0] - n_signals = csd.shape[2] - - circular_shifted_csd = np.concatenate( - [np.flip(np.conj(csd[:, 1:]), axis=1), csd[:, :-1]], axis=1) - ifft_shifted_csd = self._block_ifft( - circular_shifted_csd, self.freq_res) - lags_ifft_shifted_csd = np.reshape( - ifft_shifted_csd[:, :self.n_lags + 1], - (n_times, self.n_lags + 1, n_signals ** 2), order="F") - - signs = np.repeat([1], self.n_lags + 1).tolist() - signs[1::2] = [x * -1 for x in signs[1::2]] - sign_matrix = np.repeat( - np.tile(np.array(signs), (n_signals ** 2, 1))[np.newaxis], - n_times, axis=0).transpose(0, 2, 1) - - return np.real(np.reshape( - sign_matrix * lags_ifft_shifted_csd, - (n_times, self.n_lags + 1, n_signals, n_signals), order="F")) - - def _block_ifft(self, csd, n_points): - """Compute block iFFT with n points.""" - shape = csd.shape - csd_3d = np.reshape( - csd, (shape[0], shape[1], shape[2] * shape[3]), order="F") - - csd_ifft = np.fft.ifft(csd_3d, n=n_points, axis=1) - - return np.reshape(csd_ifft, shape, order="F") - - def _autocov_to_full_var(self, autocov): - """Compute full VAR model using Whittle's LWR recursion.""" - if np.any(np.linalg.det(autocov) == 0): - raise RuntimeError( - 'the autocovariance matrix is singular; check if your data is ' - 'rank deficient and specify an appropriate rank argument <= ' - 'the rank of the seeds and targets') - - A_f, V = self._whittle_lwr_recursion(autocov) - - if not np.isfinite(A_f).all(): - raise RuntimeError('at least one VAR model coefficient is ' - 'infinite or NaN; check the data you are using') - - try: - np.linalg.cholesky(V) - except np.linalg.LinAlgError as np_error: - raise RuntimeError( - 'the covariance matrix of the residuals is not ' - 'positive-definite; check the singular values of your data ' - 'and specify an appropriate rank argument <= the rank of the ' - 'seeds and targets') from np_error - - return A_f, V - - def _whittle_lwr_recursion(self, G): - """Solve Yule-Walker eqs. for full VAR params. with LWR recursion. - - See: Whittle P., 1963. Biometrika, DOI: 10.1093/biomet/50.1-2.129 - """ - # Initialise recursion - n = G.shape[2] # number of signals - q = G.shape[1] - 1 # number of lags - t = G.shape[0] # number of times - qn = n * q - - cov = G[:, 0, :, :] # covariance - G_f = np.reshape( - G[:, 1:, :, :].transpose(0, 3, 1, 2), (t, qn, n), - order="F") # forward autocov - G_b = np.reshape( - np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), - order="F").transpose(0, 2, 1) # backward autocov - - A_f = np.zeros((t, n, qn)) # forward coefficients - A_b = np.zeros((t, n, qn)) # backward coefficients - - k = 1 # model order - r = q - k - k_f = np.arange(k * n) # forward indices - k_b = np.arange(r * n, qn) # backward indices - - try: - A_f[:, :, k_f] = np.linalg.solve( - cov, G_b[:, k_b, :].transpose(0, 2, 1)).transpose(0, 2, 1) - A_b[:, :, k_b] = np.linalg.solve( - cov, G_f[:, k_f, :].transpose(0, 2, 1)).transpose(0, 2, 1) - - # Perform recursion - for k in np.arange(2, q + 1): - var_A = (G_b[:, (r - 1) * n: r * n, :] - - np.matmul(A_f[:, :, k_f], G_b[:, k_b, :])) - var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :]) - AA_f = np.linalg.solve( - var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) - - var_A = (G_f[:, (k - 1) * n: k * n, :] - - np.matmul(A_b[:, :, k_b], G_f[:, k_f, :])) - var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :]) - AA_b = np.linalg.solve( - var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1) - - A_f_previous = A_f[:, :, k_f] - A_b_previous = A_b[:, :, k_b] - - r = q - k - k_f = np.arange(k * n) - k_b = np.arange(r * n, qn) - - A_f[:, :, k_f] = np.dstack( - (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f)) - A_b[:, :, k_b] = np.dstack( - (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous))) - except np.linalg.LinAlgError as np_error: - raise RuntimeError( - 'the autocovariance matrix is singular; check if your data is ' - 'rank deficient and specify an appropriate rank argument <= ' - 'the rank of the seeds and targets') from np_error - - V = cov - np.matmul(A_f, G_f) - A_f = np.reshape(A_f, (t, n, n, q), order="F") - - return A_f, V - - def _full_var_to_iss(self, A_f): - """Compute innovations-form parameters for a state-space model. - - Parameters computed from a full VAR model using Aoki's method. For a - non-moving-average full VAR model, the state-space parameter C - (observation matrix) is identical to AF of the VAR model. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - t = A_f.shape[0] - m = A_f.shape[1] # number of signals - p = A_f.shape[2] // m # number of autoregressive lags - - I_p = np.dstack(t * [np.eye(m * p)]).transpose(2, 0, 1) - A = np.hstack((A_f, I_p[:, : (m * p - m), :])) # state transition - # matrix - K = np.hstack(( - np.dstack(t * [np.eye(m)]).transpose(2, 0, 1), - np.zeros((t, (m * (p - 1)), m)))) # Kalman gain matrix - - return A, K - - def _iss_to_ugc(self, A, C, K, V, seeds, targets): - """Compute unconditional GC from innovations-form state-space params. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - times = np.arange(A.shape[0]) - freqs = np.arange(self.n_freqs) - z = np.exp(-1j * np.pi * np.linspace(0, 1, self.n_freqs)) # points - # on a unit circle in the complex plane, one for each frequency - - H = self._iss_to_tf(A, C, K, z) # spectral transfer function - V_22_1 = np.linalg.cholesky(self._partial_covar(V, seeds, targets)) - HV = np.matmul(H, np.linalg.cholesky(V)) - S = np.matmul(HV, HV.conj().transpose(0, 1, 3, 2)) # Eq. 6 - S_11 = S[np.ix_(freqs, times, targets, targets)] - HV_12 = np.matmul(H[np.ix_(freqs, times, targets, seeds)], V_22_1) - HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2)) - - # Eq. 11 - return np.real( - np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH))) - - def _iss_to_tf(self, A, C, K, z): - """Compute transfer function for innovations-form state-space params. - - In the frequency domain, the back-shift operator, z, is a vector of - points on a unit circle in the complex plane. z = e^-iw, where -pi < w - <= pi. - - A note on efficiency: solving over the 4D time-freq. tensor is slower - than looping over times and freqs when n_times and n_freqs high, and - when n_times and n_freqs low, looping over times and freqs very fast - anyway (plus tensor solving doesn't allow for parallelisation). - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - t = A.shape[0] - h = self.n_freqs - n = C.shape[1] - m = A.shape[1] - I_n = np.eye(n) - I_m = np.eye(m) - H = np.zeros((h, t, n, n), dtype=np.complex128) - - parallel, parallel_compute_H, _ = parallel_func( - _gc_compute_H, self.n_jobs, verbose=False - ) - H = np.zeros((h, t, n, n), dtype=np.complex128) - for block_i in ProgressBar( - range(self.n_steps), mesg="frequency blocks" - ): - freqs = self._get_block_indices(block_i, self.n_freqs) - H[freqs] = parallel( - parallel_compute_H(A, C, K, z[k], I_n, I_m) for k in freqs) - - return H - - def _partial_covar(self, V, seeds, targets): - """Compute partial covariance of a matrix. - - Given a covariance matrix V, the partial covariance matrix of V between - indices i and j, given k (V_ij|k), is equivalent to V_ij - V_ik * - V_kk^-1 * V_kj. In this case, i and j are seeds, and k are targets. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101. - """ - times = np.arange(V.shape[0]) - W = np.linalg.solve( - np.linalg.cholesky(V[np.ix_(times, targets, targets)]), - V[np.ix_(times, targets, seeds)], - ) - W = np.matmul(W.transpose(0, 2, 1), W) - - return V[np.ix_(times, seeds, seeds)] - W - - def reshape_results(self): - """Remove time dimension from con. scores, if necessary.""" - if self.n_times == 0: - self.con_scores = self.con_scores[:, :, 0] - - -def _gc_compute_H(A, C, K, z_k, I_n, I_m): - """Compute transfer function for innovations-form state-space params. - - See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: - 10.1103/PhysRevE.91.040101, Eq. 4. - """ - from scipy import linalg # XXX: is this necessary??? - H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) - for t in range(A.shape[0]): - H[t] = I_n + np.matmul( - C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t])) - - return H - - -class _GCEst(_GCEstBase): - """[seeds -> targets] state-space GC estimator.""" - - name = "GC" - - -class _GCTREst(_GCEstBase): - """time-reversed[seeds -> targets] state-space GC estimator.""" - - name = "GC time-reversed" - -############################################################################### - - -# map names to estimator types -_CON_METHOD_MAP = {'mic': _MICEst, 'mim': _MIMEst, 'gc': _GCEst, - 'gc_tr': _GCTREst} - -_gc_methods = ['gc', 'gc_tr'] - - -@ verbose -@ fill_doc -def spectral_connectivity_epochs_multivariate( - data, names=None, method='mic', indices=None, sfreq=None, - mode='multitaper', fmin=None, fmax=np.inf, fskip=0, faverage=False, - tmin=None, tmax=None, mt_bandwidth=None, mt_adaptive=False, - mt_low_bias=True, cwt_freqs=None, cwt_n_cycles=7, gc_n_lags=40, rank=None, - block_size=1000, n_jobs=1, verbose=None -): - r"""Compute multivariate (time-)frequency-domain connectivity measures. - - The connectivity method(s) are specified using the "method" parameter. - All methods are based on estimates of the cross-spectral density (CSD). - - Parameters - ---------- - data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs - The data from which to compute connectivity. Note that it is also - possible to combine multiple signals by providing a list of tuples, - e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)], - corresponds to 3 epochs, and arr_* could be an array with the same - number of time points as stc_*. The array-like object can also - be a list/generator of array, shape =(n_signals, n_times), - or a list/generator of SourceEstimate or VolSourceEstimate objects. - %(names)s - method : str | list of str - Connectivity measure(s) to compute. These can be ``['mic', 'mim', 'gc', - 'gc_tr']``. - indices : tuple of array | None - Two arrays with indices of connections for which to compute - connectivity. Each array for the seeds and targets should consist of - nested arrays containing the channel indices for each multivariate - connection. If ``None``, connections between all channels are computed, - unless a Granger causality method is called, in which case an error is - raised. - sfreq : float - The sampling frequency. Required if data is not - :class:`Epochs `. - mode : str - Spectrum estimation mode can be either: 'multitaper', 'fourier', or - 'cwt_morlet'. - fmin : float | tuple of float - The lower frequency of interest. Multiple bands are defined using - a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. - fmax : float | tuple of float - The upper frequency of interest. Multiple bands are dedined using - a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. - fskip : int - Omit every "(fskip + 1)-th" frequency bin to decimate in frequency - domain. - faverage : bool - Average connectivity scores for each frequency band. If True, - the output freqs will be a list with arrays of the frequencies - that were averaged. - tmin : float | None - Time to start connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - tmax : float | None - Time to end connectivity estimation. Note: when "data" is an array, - the first sample is assumed to be at time 0. For other types - (Epochs, etc.), the time information contained in the object is used - to compute the time indices. - mt_bandwidth : float | None - The bandwidth of the multitaper windowing function in Hz. - Only used in 'multitaper' mode. - mt_adaptive : bool - Use adaptive weights to combine the tapered spectra into PSD. - Only used in 'multitaper' mode. - mt_low_bias : bool - Only use tapers with more than 90 percent spectral concentration - within bandwidth. Only used in 'multitaper' mode. - cwt_freqs : array - Array of frequencies of interest. Only used in 'cwt_morlet' mode. - cwt_n_cycles : float | array of float - Number of cycles. Fixed number or one per frequency. Only used in - 'cwt_morlet' mode. - gc_n_lags : int - Number of lags to use when computing Granger causality (the vector - autoregressive model order). Higher values increase computational cost, - but reduce the degree of spectral smoothing in the results. Must be < - (n_freqs - 1) * 2. Only used if ``method`` contains any of ``['gc', - 'gc_tr']``. - rank : tuple of array | None - Two arrays with the rank to project the seed and target data to, - respectively, using singular value decomposition. If None, the rank of - the data is computed and projected to. Only used if ``method`` contains - any of ``['mic', 'mim', 'gc', 'gc_tr']``. - block_size : int - How many CSD entries to compute at once (higher numbers are faster but - require more memory). - n_jobs : int - How many samples to process in parallel. - %(verbose)s - - Returns - ------- - con : array | list of array - Computed connectivity measure(s). Either an instance of - ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of the connectivity result will be: - - - ``(n_cons, n_freqs)`` for multitaper or fourier modes - - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode - - ``n_cons = 1`` when ``indices=None`` - - ``n_cons = len(indices[0])`` when indices is supplied - - See Also - -------- - mne_connectivity.spectral_connectivity_epochs - mne_connectivity.spectral_connectivity_time - mne_connectivity.SpectralConnectivity - mne_connectivity.SpectroTemporalConnectivity - - Notes - ----- - Please note that the interpretation of the measures in this function - depends on the data and underlying assumptions and does not necessarily - reflect a causal relationship between brain regions. - - These measures are not to be interpreted over time. Each Epoch passed into - the dataset is interpreted as an independent sample of the same - connectivity structure. Within each Epoch, it is assumed that the spectral - measure is stationary. The spectral measures implemented in this function - are computed across Epochs. **Thus, spectral measures computed with only - one Epoch will result in errorful values and spectral measures computed - with few Epochs will be unreliable.** Please see - ``spectral_connectivity_time`` for time-resolved connectivity estimation. - - The spectral densities can be estimated using a multitaper method with - digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier - transform with Hanning windows, or a continuous wavelet transform using - Morlet wavelets. The spectral estimation mode is specified using the - "mode" parameter. - - By default, "indices" is None, and the connectivity between all signals is - computed and a single connectivity spectrum will be returned (this is not - possible if a Granger causality method is called). If one is only - interested in the connectivity between some signals, the "indices" - parameter can be used. Seed and target indices for each connection should - be specified as nested array-likes. For example, to compute the - connectivity between signals (0, 1) -> (2, 3) and (0, 1) -> (4, 5), indices - should be specified as:: - - indices = ([[0, 1], [0, 1]], # seeds - [[2, 3], [4, 5]]) # targets - - More information on working with multivariate indices and handling - connections where the number of seeds and targets are not equal can be - found in the :doc:`../auto_examples/handling_ragged_arrays` example. - - **Supported Connectivity Measures** - - The connectivity method(s) is specified using the "method" parameter. - Multiple measures can be computed at once by using a list/tuple, e.g., - ``['mic', 'gc']``. The following methods are supported: - - 'mic' : Maximised Imaginary part of Coherency (MIC) - :footcite:`EwaldEtAl2012` given by: - - :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} - {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} - \parallel}}` - - where: :math:`\boldsymbol{E}` is the imaginary part of the - transformed cross-spectral density between seeds and targets; and - :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are - eigenvectors for the seeds and targets, such that - :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises - connectivity between the seeds and targets. - - 'mim' : Multivariate Interaction Measure (MIM) - :footcite:`EwaldEtAl2012` given by: - - :math:`MIM=tr(\boldsymbol{EE}^T)` - - 'gc' : State-space Granger Causality (GC) :footcite:`BarnettSeth2015` - given by: - - :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert - \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss - \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, - - where: :math:`s` and :math:`t` represent the seeds and targets, - respectively; :math:`\boldsymbol{H}` is the spectral transfer - function; :math:`\boldsymbol{\Sigma}` is the residuals matrix of - the autoregressive model; and :math:`\boldsymbol{S}` is - :math:`\boldsymbol{\Sigma}` transformed by :math:`\boldsymbol{H}`. - - 'gc_tr' : State-space GC on time-reversed signals - :footcite:`BarnettSeth2015,WinklerEtAl2016` given by the same equation - as for 'gc', but where the autocovariance sequence from which the - autoregressive model is produced is transposed to mimic the reversal of - the original signal in time. - - References - ---------- - .. footbibliography:: - """ - ( - fmin, fmax, n_bands, method, con_method_types, accumulate_psd, - parallel, my_epoch_spectral_connectivity - ) = _check_spectral_connectivity_epochs_settings( - method, fmin, fmax, n_jobs, verbose, _CON_METHOD_MAP) - - if n_bands != 1 and any( - this_method in _gc_methods for this_method in method - ): - raise ValueError('computing Granger causality on multiple frequency ' - 'bands is not yet supported') - - (names, times_in, sfreq, events, event_id, - metadata) = _check_spectral_connectivity_epochs_data(data, sfreq, names) - - # loop over data; it could be a generator that returns - # (n_signals x n_times) arrays or SourceEstimates - epoch_idx = 0 - logger.info('Connectivity computation...') - warn_times = True - for epoch_block in _get_n_epochs(data, n_jobs): - if epoch_idx == 0: - # initialize everything times and frequencies - (times, n_times, times_in, n_times_in, tmin_idx, tmax_idx, n_freqs, - freq_mask, freqs, freqs_bands, freq_idx_bands, n_signals, - warn_times) = _prepare_connectivity( - epoch_block=epoch_block, times_in=times_in, tmin=tmin, - tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, mode=mode, - fskip=fskip, n_bands=n_bands, cwt_freqs=cwt_freqs, - faverage=faverage) - - # check indices input - n_cons, indices_use = _check_indices(indices, method, n_signals) - - # check rank input and compute data ranks - rank = _check_rank_input(rank, data, indices_use) - - # make sure padded indices are stored in the connectivity object - if indices is not None: - indices = tuple(np.array(indices_use)) # create a copy - - # get the window function, wavelets, etc for different modes - (spectral_params, mt_adaptive, n_times_spectrum, - n_tapers) = _assemble_spectral_params( - mode=mode, n_times=n_times, mt_adaptive=mt_adaptive, - mt_bandwidth=mt_bandwidth, sfreq=sfreq, - mt_low_bias=mt_low_bias, cwt_n_cycles=cwt_n_cycles, - cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) - - # unique signals for which we actually need to compute CSD - sig_idx = np.unique(np.concatenate(np.concatenate( - indices_use))) - sig_idx = sig_idx[sig_idx != -1] - remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} - remapping[-1] = -1 - remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) - con_i = 0 - for seed, target in zip(indices_use[0], indices_use[1]): - remapped_inds[0][con_i] = np.array([ - remapping[idx] for idx in seed]) - remapped_inds[1][con_i] = np.array([ - remapping[idx] for idx in target]) - con_i += 1 - remapped_sig = [remapping[idx] for idx in sig_idx] - n_signals_use = len(sig_idx) - - # map indices to unique indices - indices_use = remapped_inds # use remapped seeds & targets - idx_map = [np.sort(np.repeat(remapped_sig, len(sig_idx))), - np.tile(remapped_sig, len(sig_idx))] - - # create instances of the connectivity estimators - con_methods = [] - for mtype_i, mtype in enumerate(con_method_types): - method_params = dict(n_cons=n_cons, n_freqs=n_freqs, - n_times=n_times_spectrum, - n_signals=n_signals_use) - if method[mtype_i] in _gc_methods: - method_params.update(dict(n_lags=gc_n_lags)) - con_methods.append(mtype(**method_params)) - - sep = ', ' - metrics_str = sep.join([meth.name for meth in con_methods]) - logger.info(' the following metrics will be computed: %s' - % metrics_str) - - call_params = dict( - sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, - method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, - n_cons=n_cons, block_size=block_size, - psd=None, accumulate_psd=accumulate_psd, - mt_adaptive=mt_adaptive, - con_method_types=con_method_types, - con_methods=con_methods if n_jobs == 1 else None, - n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, - gc_n_lags=gc_n_lags, multivariate_con=True, - accumulate_inplace=True if n_jobs == 1 else False) - call_params.update(**spectral_params) - - epoch_idx = _compute_spectral_methods_epochs( - con_methods, epoch_block, epoch_idx, call_params, parallel, - my_epoch_spectral_connectivity, n_jobs, n_times_in, times_in, - warn_times) - n_epochs = epoch_idx - - # compute final connectivity scores - con = list() - patterns = list() - for conn_method in con_methods: - - # compute connectivity scores - conn_method.compute_con(indices_use, rank, n_epochs) - - # get the connectivity scores - this_con = conn_method.con_scores - this_patterns = conn_method.patterns - - if this_con.shape[0] != n_cons: - raise RuntimeError( - 'first dimension of connectivity scores does not match the ' - 'number of connections; please contact the mne-connectivity ' - 'developers') - if faverage: - if this_con.shape[1] != n_freqs: - raise RuntimeError( - 'second dimension of connectivity scores does not match ' - 'the number of frequencies; please contact the ' - 'mne-connectivity developers') - con_shape = (n_cons, n_bands) + this_con.shape[2:] - this_con_bands = np.empty(con_shape, dtype=this_con.dtype) - for band_idx in range(n_bands): - this_con_bands[:, band_idx] = np.mean( - this_con[:, freq_idx_bands[band_idx]], axis=1) - this_con = this_con_bands - - if this_patterns is not None: - patterns_shape = list(this_patterns.shape) - patterns_shape[3] = n_bands - this_patterns_bands = np.empty(patterns_shape, - dtype=this_patterns.dtype) - for band_idx in range(n_bands): - this_patterns_bands[:, :, :, band_idx] = np.mean( - this_patterns[:, :, :, freq_idx_bands[band_idx]], - axis=3) - this_patterns = this_patterns_bands - - con.append(this_con) - patterns.append(this_patterns) - - conn_list = _store_results( - con=con, patterns=patterns, method=method, freqs=freqs, - faverage=faverage, freqs_bands=freqs_bands, names=names, mode=mode, - indices=indices, n_epochs=n_epochs, times=times, n_tapers=n_tapers, - metadata=metadata, events=events, event_id=event_id, rank=rank, - gc_n_lags=gc_n_lags, n_signals=n_signals) - - return conn_list diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 54bfafa5..592291f0 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -9,11 +9,10 @@ from mne_connectivity import ( SpectralConnectivity, spectral_connectivity_epochs, - spectral_connectivity_epochs_multivariate, read_connectivity, spectral_connectivity_time) +from mne_connectivity.spectral.epochs import _CohEst, _get_n_epochs from mne_connectivity.spectral.epochs import ( - _get_n_epochs, _compute_freq_mask, _compute_freqs) -from mne_connectivity.spectral.epochs_bivariate import _CohEst + _compute_freq_mask, _compute_freqs) def create_test_dataset(sfreq, n_signals, n_epochs, n_times, tmin, tmax, @@ -448,7 +447,7 @@ def test_spectral_connectivity_epochs_multivariate(method): data = data.reshape(n_signals, n_epochs, n_times) data = np.transpose(data, [1, 0, 2]) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=20) freqs = con.freqs @@ -474,17 +473,17 @@ def test_spectral_connectivity_epochs_multivariate(method): # check that target -> seed connectivity is low indices_ts = (indices[1], indices[0]) - con_ts = spectral_connectivity_epochs_multivariate( + con_ts = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices_ts, sfreq=sfreq, gc_n_lags=20) assert con_ts.get_data()[0, gidx[0]:gidx[1]].mean() < lower_t # check that TRGC is positive (i.e. net seed -> target connectivity not # due to noise) - con_tr = spectral_connectivity_epochs_multivariate( + con_tr = spectral_connectivity_epochs( data, method='gc_tr', mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=20) - con_ts_tr = spectral_connectivity_epochs_multivariate( + con_ts_tr = spectral_connectivity_epochs( data, method='gc_tr', mode=mode, indices=indices_ts, sfreq=sfreq, gc_n_lags=20) trgc = ((con.get_data() - con_ts.get_data()) - @@ -498,7 +497,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check all-to-all conn. computed for MIC/MIM when no indices given if method in ['mic', 'mim']: - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=None, sfreq=sfreq) assert con.indices is None assert con.n_nodes == n_signals @@ -507,7 +506,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check ragged indices padded correctly ragged_indices = (np.array([[0]]), np.array([[1, 2]])) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq) assert np.all(np.array(con.indices) == np.array([np.array([[0, -1]]), np.array([[1, 2]])])) @@ -515,7 +514,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check shape of MIC patterns if method == 'mic': for mode in ['multitaper', 'cwt_morlet']: - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=10, fmax=25, cwt_freqs=np.arange(10, 25), faverage=True) @@ -536,7 +535,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check patterns averaged over freqs fmin = (5., 15.) fmax = (15., 30.) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=fmin, fmax=fmax, faverage=True) assert np.shape(con.attrs["patterns"][0][0])[1] == len(fmin) @@ -544,7 +543,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check patterns shape matches input data, not rank rank = (np.array([1]), np.array([1])) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=rank) assert (np.shape(con.attrs["patterns"][0][0])[0] == n_seeds) @@ -552,7 +551,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # check patterns padded correctly ragged_indices = (np.array([[0]]), np.array([[1, 2]])) - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq) patterns = np.array(con.attrs["patterns"]) @@ -587,7 +586,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): sfreq = 100 indices = (np.array([[0, 1]]), np.array([[2, 3]])) methods = ['mic', 'mim', 'gc', 'gc_tr'] - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( data, method=methods, indices=indices, mode='multitaper', sfreq=sfreq, fskip=0, faverage=False, tmin=0, tmax=None, mt_bandwidth=4, mt_low_bias=True, mt_adaptive=False, gc_n_lags=20, @@ -595,9 +594,8 @@ def test_multivariate_spectral_connectivity_epochs_regression(): # should take the absolute of the MIC scores, as the MATLAB implementation # returns the absolute values. - mne_results = {this_con.method: this_con.get_data() for this_con in con} - mne_results["mic"] = np.abs(mne_results["mic"]) - + mne_results = {this_con.method: np.abs(this_con.get_data()) + for this_con in con} matlab_results = pd.read_pickle( os.path.join(fpath, 'data', 'example_multivariate_matlab_results.pkl')) for method in methods: @@ -622,29 +620,40 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): with pytest.raises(TypeError, match='multivariate indices must contain array-likes'): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=non_nested_indices, - sfreq=sfreq, cwt_freqs=cwt_freqs, gc_n_lags=10) + sfreq=sfreq, gc_n_lags=10) # check bad indices with repeated channels caught with pytest.raises(ValueError, match='multivariate indices cannot contain repeated'): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=repeated_indices, - sfreq=sfreq, cwt_freqs=cwt_freqs, gc_n_lags=10) + sfreq=sfreq, gc_n_lags=10) + + # check mixed methods caught + with pytest.raises(ValueError, + match='bivariate and multivariate connectivity'): + if isinstance(method, str): + mixed_methods = [method, 'coh'] + elif isinstance(method, list): + mixed_methods = [*method, 'coh'] + spectral_connectivity_epochs(data, method=mixed_methods, mode=mode, + indices=indices, sfreq=sfreq, + cwt_freqs=cwt_freqs) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) with pytest.raises(ValueError, match='ranks for seeds and targets must be'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=too_low_rank, cwt_freqs=cwt_freqs) too_high_rank = (np.array([3]), np.array([3])) with pytest.raises(ValueError, match='ranks for seeds and targets must be'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=too_high_rank, cwt_freqs=cwt_freqs) @@ -655,7 +664,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): assert np.all(np.linalg.matrix_rank(bad_data[:, (0, 1), :]) == 1) assert np.all(np.linalg.matrix_rank(bad_data[:, (2, 3), :]) == 1) if isinstance(method, str): - rank_con = spectral_connectivity_epochs_multivariate( + rank_con = spectral_connectivity_epochs( bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, gc_n_lags=10, cwt_freqs=cwt_freqs) assert rank_con.attrs["rank"] == (np.array([1]), np.array([1])) @@ -664,7 +673,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): # check rank-deficient transformation matrix caught with pytest.raises(RuntimeError, match='the transformation matrix'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=(np.array([2]), np.array([2])), cwt_freqs=cwt_freqs) @@ -675,36 +684,37 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): frange = (5, 10) n_lags = 200 # will be far too high with pytest.raises(ValueError, match='the number of lags'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=frange[0], fmax=frange[1], gc_n_lags=n_lags, cwt_freqs=cwt_freqs) # check no indices caught with pytest.raises(ValueError, match='indices must be specified'): - spectral_connectivity_epochs_multivariate( - data, method=method, mode=mode, indices=None, sfreq=sfreq, - cwt_freqs=cwt_freqs) + spectral_connectivity_epochs(data, method=method, mode=mode, + indices=None, sfreq=sfreq, + cwt_freqs=cwt_freqs) # check intersecting indices caught bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) with pytest.raises(ValueError, match='seed and target indices must not intersect'): - spectral_connectivity_epochs_multivariate( - data, method=method, mode=mode, indices=bad_indices, - sfreq=sfreq, cwt_freqs=cwt_freqs) + spectral_connectivity_epochs(data, method=method, mode=mode, + indices=bad_indices, sfreq=sfreq, + cwt_freqs=cwt_freqs) # check bad fmin/fmax caught with pytest.raises(ValueError, match='computing Granger causality on multiple'): - spectral_connectivity_epochs_multivariate( - data, method=method, mode=mode, indices=indices, sfreq=sfreq, - fmin=(10., 15.), fmax=(15., 20.), cwt_freqs=cwt_freqs) + spectral_connectivity_epochs(data, method=method, mode=mode, + indices=indices, sfreq=sfreq, + fmin=(10., 15.), fmax=(15., 20.), + cwt_freqs=cwt_freqs) # check rank-deficient autocovariance caught with pytest.raises(RuntimeError, match='the autocovariance matrix is singular'): - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=(np.array([2]), np.array([2])), cwt_freqs=cwt_freqs) @@ -721,7 +731,7 @@ def test_multivar_spectral_connectivity_parallel(method): data = rng.randn(n_epochs, n_signals, n_times) indices = (np.array([[0, 1]]), np.array([[2, 3]])) - spectral_connectivity_epochs_multivariate( + spectral_connectivity_epochs( data, method=method, mode="multitaper", indices=indices, sfreq=sfreq, gc_n_lags=10, n_jobs=2) spectral_connectivity_time( @@ -751,13 +761,12 @@ def test_multivar_spectral_connectivity_flipped_indices(): # we test on GC since this is a directed connectivity measure method = 'gc' - con_st = spectral_connectivity_epochs_multivariate( # seed -> target + con_st = spectral_connectivity_epochs( # seed -> target data, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10) - con_ts = spectral_connectivity_epochs_multivariate( # target -> seed + con_ts = spectral_connectivity_epochs( # target -> seed data, method=method, indices=flipped_indices, sfreq=sfreq, gc_n_lags=10) - con_st_ts = spectral_connectivity_epochs_multivariate( - # seed -> target; target -> seed + con_st_ts = spectral_connectivity_epochs( # seed -> target; target -> seed data, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10) assert not np.all(con_st.get_data() == con_ts.get_data()) assert np.all(con_st.get_data()[0] == con_st_ts.get_data()[0]) @@ -1289,7 +1298,7 @@ def test_multivar_save_load(tmp_path): non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]])) ragged_indices = (np.array([[0, 1]]), np.array([[2]])) for indices in [non_ragged_indices, ragged_indices]: - con = spectral_connectivity_epochs_multivariate( + con = spectral_connectivity_epochs( epochs, method=['mic', 'mim', 'gc', 'gc_tr'], indices=indices, sfreq=sfreq, fmin=10, fmax=30) for this_con in con: @@ -1306,9 +1315,12 @@ def test_multivar_save_load(tmp_path): assert a == b -@pytest.mark.parametrize("method", ['coh', 'plv', 'pli', 'wpli', 'ciplv']) +@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", + "mic", "mim"]) @pytest.mark.parametrize("indices", [None, - (np.array([0, 1]), np.array([2, 3]))]) + (np.array([0, 1]), np.array([2, 3])), + (np.array([[0, 1]]), np.array([[2, 3]])) + ]) def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. @@ -1325,6 +1337,14 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") + # mutlivariate and bivariate methods require the right indices shape + if method in ["mic", "mim"]: + if indices is not None and indices[0].ndim == 1: + pytest.skip() + else: + if indices is not None and indices[0].ndim == 2: + pytest.skip() + # test the pair of method and indices defined to check the output indices con_epochs = spectral_connectivity_epochs( epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 @@ -1346,53 +1366,3 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): assert np.all(np.array(con.indices) == np.array(read_con.indices)) else: assert con.indices is None and read_con.indices is None - - -@pytest.mark.parametrize("method", ['mic', 'mim', 'gc', 'gc_tr']) -@pytest.mark.parametrize("indices", [None, - (np.array([[0, 1]]), np.array([[2, 3]]))]) -def test_multivar_spectral_connectivity_indices_roundtrip_io( - tmp_path, method, indices -): - """Test that indices values and type is maintained after saving. - - If `indices` is None, `indices` in the returned connectivity object should - be None, otherwise, `indices` should be a tuple. The type of `indices` and - its values should be retained after saving and reloading. - """ - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 - data = rng.randn(n_epochs, n_chs, n_times) - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) - freqs = np.arange(10, 31) - tmp_file = os.path.join(tmp_path, "foo_mvc.nc") - - # test the pair of method and indices defined to check the output indices - if indices is None and method in ['gc', 'gc_tr']: - # indicesmust be specified for GC - pytest.skip() - - con_epochs = spectral_connectivity_epochs_multivariate( - epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30, - gc_n_lags=10 - ) - con_time = spectral_connectivity_time( - epochs, freqs, method=method, indices=indices, sfreq=sfreq, - gc_n_lags=10 - ) - - for con in [con_epochs, con_time]: - con.save(tmp_file) - read_con = read_connectivity(tmp_file) - - if indices is not None: - # check indices of same type (tuples) - assert isinstance(con.indices, tuple) and isinstance( - read_con.indices, tuple - ) - # check indices have same values - assert np.all(np.array(con.indices) == np.array(read_con.indices)) - else: - assert con.indices is None and read_con.indices is None diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index d0059ace..3798f699 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -13,9 +13,8 @@ from mne.utils import (logger, verbose) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) -from .epochs import _compute_freq_mask -from .epochs_multivariate import (_MICEst, _MIMEst, _GCEst, _GCTREst, - _check_rank_input) +from .epochs import (_MICEst, _MIMEst, _GCEst, _GCTREst, _compute_freq_mask, + _check_rank_input) from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, check_multivariate_indices, fill_doc From fd619494fb4fe6d4b8e1f62754cd8438b2f20d23 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 15:57:51 +0100 Subject: [PATCH 34/40] updated time --- mne_connectivity/spectral/time.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 6258752d..340779ef 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -407,8 +407,10 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') logger.info('using all indices for multivariate connectivity') - indices_use = (np.array([np.arange(n_signals, dtype=np.int32)]), - np.array([np.arange(n_signals, dtype=np.int32)])) + indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) + indices_use = np.ma.masked_array(indices_use, + mask=False, fill_value=-1) else: logger.info('only using indices for lower-triangular matrix') indices_use = np.tril_indices(n_signals, k=-1) @@ -437,20 +439,20 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, # unique signals for which we actually need to compute the CSD of if multivariate_con: - signals_use = np.unique(np.concatenate(np.concatenate(indices_use))) - signals_use = signals_use[signals_use != -1] + signals_use = np.unique(indices_use.compressed()) remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(signals_use)} - remapping[-1] = -1 + remapped_inds = indices_use.copy() # multivariate functions expect seed/target remapping - con_i = 0 - for seed, target in zip(indices_use[0], indices_use[1]): - source_idx[con_i] = np.array([remapping[idx] for idx in seed]) - target_idx[con_i] = np.array([remapping[idx] for idx in target]) - con_i += 1 + for idx in signals_use: + remapped_inds[indices_use == idx] = remapping[idx] + source_idx = remapped_inds[0] + target_idx = remapped_inds[1] max_n_channels = len(indices_use[0][0]) else: # no indices remapping required for bivariate functions signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + source_idx = indices_use[0].copy() + target_idx = indices_use[1].copy() max_n_channels = len(indices_use[0]) # check rank input and compute data ranks if necessary From 900f7251b39c3a40050e5071a03a0fc3186a63a7 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 16:13:43 +0100 Subject: [PATCH 35/40] switched to masked indices for multivariate conn --- examples/handling_ragged_arrays.py | 39 +++---- mne_connectivity/io.py | 6 +- .../spectral/tests/test_spectral.py | 15 +-- mne_connectivity/spectral/time.py | 15 ++- mne_connectivity/tests/test_utils.py | 16 ++- mne_connectivity/utils/utils.py | 105 +++++++++--------- 6 files changed, 94 insertions(+), 102 deletions(-) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index 6ab01791..8076b9e2 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -62,8 +62,8 @@ # ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'), # np.array([[2, 3, 4], [4 ]], dtype='object')) # -# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be -# specified.** +# **N.B. Note that since NumPy v1.19.0, dtype='object' must be specified when +# forming ragged arrays.** # # Just as for bivariate connectivity, the length of ``indices[0]`` and # ``indices[1]`` is equal (i.e. the number of connections), however information @@ -71,19 +71,16 @@ # array. Importantly, these indices are ragged, as the first connection will be # computed between 2 seed and 3 target channels, and the second connection # between 4 seed and 1 target channel. The connectivity functions will -# recognise the indices as being ragged, and pad them accordingly to make them -# easier to work with and compatible with the h5netcdf saving engine. The known -# value used to pad the arrays is ``-1``, an invalid channel index. The above -# indices would be padded to:: +# recognise the indices as being ragged, and pad them to a 'full' array by +# adding placeholder values which are masked accordingly. This makes the +# indices easier to work with, and also compatible with the engine used to save +# connectivity objects. For example, the above indices would become:: # -# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]), -# np.array([[2, 3, 4, -1], [4, -1, -1, -1]])) +# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]), +# np.array([[2, 3, 4, --], [4, --, --, --]])) # -# These indices are what is stored in the connectivity object, and is also the -# format of indices returned from the helper functions -# :func:`~mne_connectivity.check_multivariate_indices` and -# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also -# possible to pass the padded indices to the connectivity functions directly. +# where ``--`` are masked entries. These indices are what is stored in the +# returned connectivity objects. # # For the connectivity results themselves, the methods available in # MNE-Connectivity combine information across the different channels into a @@ -118,11 +115,11 @@ max_n_chans = max( len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])) -# show that the padded indices entries are all -1 -assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels -assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels -assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels -assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels +# show that the padded indices entries are masked +assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels +assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels +assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels +assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels # patterns have shape [seeds/targets x cons x max channels x freqs (x times)] assert patterns.shape == (2, n_cons, max_n_chans, n_freqs) @@ -137,11 +134,11 @@ seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] -# extract patterns for second connection using the padded indices (pad = -1) +# extract patterns for second connection using the padded, masked indices seed_patterns_con2 = ( - patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)]) + patterns[0, 1, :padded_indices[0][1].count()]) target_patterns_con2 = ( - patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)]) + patterns[1, 1, :padded_indices[1][1].count()]) # show that shapes of patterns are correct assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) diff --git a/mne_connectivity/io.py b/mne_connectivity/io.py index 63aa3501..e8d9b916 100644 --- a/mne_connectivity/io.py +++ b/mne_connectivity/io.py @@ -53,9 +53,11 @@ def _xarray_to_conn(array, cls_func): event_id = dict(zip(event_id_keys, event_id_vals)) array.attrs['event_id'] = event_id - # convert indices numpy arrays to a tuple of arrays + # convert indices numpy arrays to a tuple of masked arrays + # (only multivariate connectivity indices saved as arrays) if isinstance(array.attrs['indices'], np.ndarray): - array.attrs['indices'] = tuple(array.attrs['indices']) + array.attrs['indices'] = tuple( + np.ma.masked_values(array.attrs['indices'], -1)) # create the connectivity class conn = cls_func( diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 273fa2d5..a514382b 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1315,12 +1315,9 @@ def test_multivar_save_load(tmp_path): assert a == b -@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", - "mic", "mim"]) +@pytest.mark.parametrize("method", ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize("indices", [None, - (np.array([0, 1]), np.array([2, 3])), - (np.array([[0, 1]]), np.array([[2, 3]])) - ]) + (np.array([0, 1]), np.array([2, 3]))]) def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. @@ -1337,14 +1334,6 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") - # mutlivariate and bivariate methods require the right indices shape - if method in ["mic", "mim"]: - if indices is not None and indices[0].ndim == 1: - pytest.skip() - else: - if indices is not None and indices[0].ndim == 2: - pytest.skip() - # test the pair of method and indices defined to check the output indices con_epochs = spectral_connectivity_epochs( epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 340779ef..4ac3f161 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -422,20 +422,19 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, indices_use]) np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices[0], indices[1]): - intersection = np.intersect1d(seed, target) - if np.any(intersection != -1): # ignore padded entries + for seed, target in zip(indices_use[0], indices_use[1]): + intersection = np.intersect1d(seed.compressed(), + target.compressed()) + if intersection.size > 0: raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') # make sure padded indices are stored in the connectivity object - indices = tuple(np.array(indices_use)) # create a copy + # create a copy + indices = (indices_use[0].copy(), indices_use[1].copy()) else: indices_use = check_indices(indices) - # create copies of indices_use for independent manipulation - source_idx = np.array(indices_use[0]) - target_idx = np.array(indices_use[1]) - n_cons = len(source_idx) + n_cons = len(indices_use[0]) # unique signals for which we actually need to compute the CSD of if multivariate_con: diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index 31f27d0a..caead679 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -34,14 +34,22 @@ def test_seed_target_indices(): seeds = [[0, 1]] targets = [[2, 3], [3, 4]] indices = seed_target_multivariate_indices(seeds, targets) - assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), - np.array([[2, 3], [3, 4]]))) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3], [3, 4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) # ragged indices seeds = [[0, 1]] targets = [[2, 3, 4], [4]] indices = seed_target_multivariate_indices(seeds, targets) - assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), - np.array([[2, 3, 4], [4, -1, -1]]))) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3, 4], [4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) # test error catching # non-array-like seeds/targets with pytest.raises(TypeError, diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 6696a130..2cab0007 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -98,7 +98,7 @@ def _check_multivariate_indices(indices, n_chans): Returns ------- indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans) - The indices padded with the invalid channel index ``-1``. + The indices as a masked array. Notes ----- @@ -112,26 +112,35 @@ def _check_multivariate_indices(indices, n_chans): connection must be unique. If the seed and target indices are given as lists or tuples, they will be - converted to numpy arrays. In case the number of channels differs across + converted to numpy arrays. Because the number of channels can differ across connections or between the seeds and targets for a given connection (i.e. - ragged indices), the returned array will be padded with the invalid channel - index ``-1`` according to the maximum number of channels in the seed or - target of any one connection. E.g. the ragged indices of shape ``(2, - n_cons, variable)``:: + ragged/jagged indices), the returned array will be padded out to a 'full' + array with an invalid index (``-1``) according to the maximum number of + channels in the seed or target of any one connection. These invalid + entries are then masked and returned as numpy masked arrays. E.g. the + ragged indices of shape ``(2, n_cons, variable)``:: indices = ([[0, 1], [0, 1 ]], # seeds [[2, 3], [4, 5, 6]]) # targets - would be returned as:: + would be padded to full arrays:: + + indices = ([[0, 1, -1], [0, 1, -1]], # seeds + [[2, 3, -1], [4, 5, 6]]) # targets + + to have shape ``(2, n_cons, max_n_chans)``, where ``max_n_chans = 3``. The + invalid entries are then masked:: - indices = (np.array([[0, 1, -1], [0, 1, -1]]), # seeds - np.array([[2, 3, -1], [4, 5, -1]])) # targets + indices = ([[0, 1, --], [0, 1, --]], # seeds + [[2, 3, --], [4, 5, 6]]) # targets - where the indices have been padded with ``-1`` to have shape ``(2, n_cons, - max_n_chans)``, where ``max_n_chans = 3``. More information on working with - multivariate indices and handling connections where the number of seeds and - targets are not equal can be found in the - :doc:`../auto_examples/handling_ragged_arrays` example. + In case "indices" contains negative values to index channels, these will be + converted to the corresponding positive-valued index before any masking is + applied. + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. """ if not isinstance(indices, tuple) or len(indices) != 2: raise ValueError('indices must be a tuple of length 2') @@ -141,6 +150,7 @@ def _check_multivariate_indices(indices, n_chans): 'have the same length') n_cons = len(indices[0]) + invalid = -1 max_n_chans = 0 for group_idx, group in enumerate(indices): @@ -166,15 +176,19 @@ def _check_multivariate_indices(indices, n_chans): indices[group_idx][con_idx][chan_idx] = chan % n_chans # pad indices to avoid ragged arrays - padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32), - np.full((n_cons, max_n_chans), -1, dtype=np.int32)) + padded_indices = (np.full((n_cons, max_n_chans), invalid, dtype=np.int32), + np.full((n_cons, max_n_chans), invalid, dtype=np.int32)) con_i = 0 for seed, target in zip(indices[0], indices[1]): padded_indices[0][con_i, :len(seed)] = seed padded_indices[1][con_i, :len(target)] = target con_i += 1 - return padded_indices + # mask invalid indices + masked_indices = (np.ma.masked_values(padded_indices[0], invalid), + np.ma.masked_values(padded_indices[1], invalid)) + + return masked_indices def seed_target_indices(seeds, targets): @@ -236,8 +250,8 @@ def seed_target_multivariate_indices(seeds, targets): Returns ------- - indices : tuple of array of array of int, shape (2, n_cons, max_n_chans) - The indices padded with the invalid channel index ``-1``. + indices : tuple of array of array of int, shape (2, n_cons, variable) + The indices as a numpy object array. Notes ----- @@ -247,12 +261,8 @@ def seed_target_multivariate_indices(seeds, targets): channels in the data. The length of indices for each connection do not need to be equal. Furthermore, all indices within a connection must be unique. - ``seeds`` and ``targets`` will be expanded such that connectivity will be - computed between each set of seeds and targets. In case the number of - channels differs across connections or between the seeds and targets for a - given connection (i.e. ragged indices), the returned array will be padded - with the invalid channel index ``-1`` according to the maximum number of - channels in the seed or target of any one connection. E.g. ``seeds`` and + Because the number of channels per connection can vary, the indices are + stored as numpy arrays with ``dtype=object``. E.g. ``seeds`` and ``targets``:: seeds = [[0]] @@ -260,15 +270,15 @@ def seed_target_multivariate_indices(seeds, targets): would be returned as:: - indices = (np.array([[0, -1, -1], [0, -1, -1]]), # seeds - np.array([[1, 2, -1], [3, 4, 5]])) # targets + indices = (np.array([[0 ], [0 ]], dtype=object), # seeds + np.array([[1, 2], [3, 4, 5]], dtype=object)) # targets + + Even if the number of channels does not vary, the indices will still be + stored as object arrays for compatibility. - where the indices have been padded with ``-1`` to have shape ``(2, n_cons, - max_n_chans)``, where ``n_cons = n_unique_seeds * n_unique_targets`` and - ``max_n_chans = 3``. More information on working with multivariate indices - and handling connections where the number of seeds and targets are not - equal can be found in the :doc:`../auto_examples/handling_ragged_arrays` - example. + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. """ array_like = (np.ndarray, list, tuple) @@ -278,7 +288,6 @@ def seed_target_multivariate_indices(seeds, targets): ): raise TypeError('`seeds` and `targets` must be array-like') - n_chans = [] for inds in [*seeds, *targets]: if not isinstance(inds, array_like): raise TypeError( @@ -286,27 +295,15 @@ def seed_target_multivariate_indices(seeds, targets): if len(inds) != len(np.unique(inds)): raise ValueError( '`seeds` and `targets` cannot contain repeated channels') - n_chans.append(len(inds)) - max_n_chans = max(n_chans) - n_cons = len(seeds) * len(targets) - # pad indices to avoid ragged arrays - padded_seeds = np.full((len(seeds), max_n_chans), -1, dtype=np.int32) - padded_targets = np.full((len(targets), max_n_chans), -1, dtype=np.int32) - for con_i, seed in enumerate(seeds): - padded_seeds[con_i, :len(seed)] = seed - for con_i, target in enumerate(targets): - padded_targets[con_i, :len(target)] = target - - # create final indices - indices = (np.zeros((n_cons, max_n_chans), dtype=np.int32), - np.zeros((n_cons, max_n_chans), dtype=np.int32)) - con_i = 0 - for seed in padded_seeds: - for target in padded_targets: - indices[0][con_i] = seed - indices[1][con_i] = target - con_i += 1 + indices = [[], []] + for seed in seeds: + for target in targets: + indices[0].append(np.array(seed)) + indices[1].append(np.array(target)) + + indices = (np.array(indices[0], dtype=object), + np.array(indices[1], dtype=object)) return indices From 1ad464ccf28c2ff5843f86e814d0568a6b359fdf Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 19:12:15 +0100 Subject: [PATCH 36/40] removed redundant ignored word Co-authored-by: Eric Larson --- ignore_words.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/ignore_words.txt b/ignore_words.txt index 24f6f63c..304ce95a 100644 --- a/ignore_words.txt +++ b/ignore_words.txt @@ -1,5 +1,4 @@ nd adn ba -BA manuel From 334afef2aa98ee2d39d9d6d36d6145ca02e91f76 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 19:12:49 +0100 Subject: [PATCH 37/40] removed redundant list creation Co-authored-by: Eric Larson --- mne_connectivity/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/base.py b/mne_connectivity/base.py index 672448d1..8123087c 100644 --- a/mne_connectivity/base.py +++ b/mne_connectivity/base.py @@ -483,9 +483,9 @@ def _prepare_xarray(self, data, names, indices, n_nodes, method, # set method, indices and n_nodes if isinstance(indices, tuple): - if all([isinstance(inds, np.ndarray) for inds in indices]): + if all(isinstance(inds, np.ndarray) for inds in indices): # leave multivariate indices as arrays for easier indexing - if all([inds.ndim > 1 for inds in indices]): + if all(inds.ndim > 1 for inds in indices): new_indices = (indices[0], indices[1]) else: new_indices = (list(indices[0]), list(indices[1])) From 4e44edef80cdd46fd999fa3e89462a0b8aaecc6e Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 19:29:46 +0100 Subject: [PATCH 38/40] updated default non-zero rank tolerance --- examples/granger_causality.py | 16 ++++++++-------- examples/mic_mim.py | 16 ++++++++-------- mne_connectivity/spectral/epochs.py | 2 +- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/granger_causality.py b/examples/granger_causality.py index 64a657db..73c10a86 100644 --- a/examples/granger_causality.py +++ b/examples/granger_causality.py @@ -343,21 +343,21 @@ # an automatic rank computation is performed and an appropriate degree of # dimensionality reduction will be enforced. The rank of the data is determined # by computing the singular values of the data and finding those within a -# factor of :math:`1e^{-10}` relative to the largest singular value. +# factor of :math:`1e^{-6}` relative to the largest singular value. # -# In some circumstances, this threshold may be too lenient, in which case you -# should inspect the singular values of your data to identify an appropriate -# degree of dimensionality reduction to perform, which you can then specify -# manually using the ``rank`` argument. The code below shows one possible -# approach for finding an appropriate rank of close-to-singular data with a -# more conservative threshold of :math:`1e^{-5}`. +# Whilst unlikely, there may be scenarios in which this threshold may be too +# lenient. In these cases, you should inspect the singular values of your data +# to identify an appropriate degree of dimensionality reduction to perform, +# which you can then specify manually using the ``rank`` argument. The code +# below shows one possible approach for finding an appropriate rank of +# close-to-singular data with a more conservative threshold. # %% # gets the singular values of the data s = np.linalg.svd(raw.get_data(), compute_uv=False) # finds how many singular values are 'close' to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria +rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria ############################################################################### # Nonethless, even in situations where you specify an appropriate rank, it is diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 87111586..86044969 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -372,21 +372,21 @@ # an automatic rank computation is performed and an appropriate degree of # dimensionality reduction will be enforced. The rank of the data is determined # by computing the singular values of the data and finding those within a -# factor of :math:`1e^{-10}` relative to the largest singular value. +# factor of :math:`1e^{-6}` relative to the largest singular value. # -# In some circumstances, this threshold may be too lenient, in which case you -# should inspect the singular values of your data to identify an appropriate -# degree of dimensionality reduction to perform, which you can then specify -# manually using the ``rank`` argument. The code below shows one possible -# approach for finding an appropriate rank of close-to-singular data with a -# more conservative threshold of :math:`1e^{-5}`. +# Whilst unlikely, there may be scenarios in which this threshold may be too +# lenient. In these cases, you should inspect the singular values of your data +# to identify an appropriate degree of dimensionality reduction to perform, +# which you can then specify manually using the ``rank`` argument. The code +# below shows one possible approach for finding an appropriate rank of +# close-to-singular data with a more conservative threshold. # %% # gets the singular values of the data s = np.linalg.svd(raw.get_data(), compute_uv=False) # finds how many singular values are 'close' to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria +rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria ############################################################################### diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 914f8531..63c32e26 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -198,7 +198,7 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, def _check_rank_input(rank, data, indices): """Check the rank argument is appropriate and compute rank if missing.""" - sv_tol = 1e-10 # tolerance for non-zero singular val (rel. to largest) + sv_tol = 1e-6 # tolerance for non-zero singular val (rel. to largest) if rank is None: rank = np.zeros((2, len(indices[0])), dtype=int) From 000018991e85ea25a94905e9c85535fc45ade9df Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 3 Nov 2023 13:46:20 +0100 Subject: [PATCH 39/40] switched to array indices & added inline comments --- mne_connectivity/spectral/epochs.py | 8 +++---- mne_connectivity/spectral/time.py | 8 +++---- mne_connectivity/tests/test_utils.py | 32 ++++++++++++++-------------- mne_connectivity/utils/utils.py | 18 ++++++---------- 4 files changed, 29 insertions(+), 37 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 63c32e26..f4899e87 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -105,6 +105,7 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, 'causality, as all-to-all connectivity is not supported') else: logger.info('using all indices for multivariate connectivity') + # indices expected to be a masked array, even if not ragged indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], np.arange(n_signals, dtype=int)[np.newaxis, :]) indices_use = np.ma.masked_array(indices_use, @@ -115,11 +116,8 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, indices_use = np.tril_indices(n_signals, -1) else: if multivariate_con: - # mask indices + # pad ragged indices and mask the invalid entries indices_use = _check_multivariate_indices(indices, n_signals) - indices_use = np.ma.concatenate([inds[np.newaxis] for inds in - indices_use]) - np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. if any(this_method in _gc_methods for this_method in method): for seed, target in zip(indices_use[0], indices_use[1]): intersection = np.intersect1d(seed.compressed(), @@ -1908,7 +1906,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # make sure padded indices are stored in the connectivity object if multivariate_con and indices is not None: - # create a copy + # create a copy so that `indices_use` can be modified indices = (indices_use[0].copy(), indices_use[1].copy()) # get the window function, wavelets, etc for different modes diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 4ac3f161..07dc4e57 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -407,6 +407,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') logger.info('using all indices for multivariate connectivity') + # indices expected to be a masked array, even if not ragged indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], np.arange(n_signals, dtype=int)[np.newaxis, :]) indices_use = np.ma.masked_array(indices_use, @@ -416,11 +417,8 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, indices_use = np.tril_indices(n_signals, k=-1) else: if multivariate_con: - # mask indices + # pad ragged indices and mask the invalid entries indices_use = _check_multivariate_indices(indices, n_signals) - indices_use = np.ma.concatenate([inds[np.newaxis] for inds in - indices_use]) - np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. if any(this_method in _gc_methods for this_method in method): for seed, target in zip(indices_use[0], indices_use[1]): intersection = np.intersect1d(seed.compressed(), @@ -430,7 +428,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'seed and target indices must not intersect when ' 'computing Granger causality') # make sure padded indices are stored in the connectivity object - # create a copy + # create a copy so that `indices_use` can be modified indices = (indices_use[0].copy(), indices_use[1].copy()) else: indices_use = check_indices(indices) diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index caead679..d80f52a6 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -94,34 +94,34 @@ def test_check_multivariate_indices(): seeds = [[0, 1], [0, 1]] targets = [[2, 3], [3, 4]] indices = _check_multivariate_indices((seeds, targets), n_signals) - assert all(np.ma.isMA(inds) for inds in indices) - assert all(inds.fill_value == mask_value for inds in indices) - assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), - np.array([[2, 3], [3, 4]]))) + assert np.ma.isMA(indices) + assert indices.fill_value == mask_value + assert np.all(indices == np.array([[[0, 1], [0, 1]], + [[2, 3], [3, 4]]])) # non-ragged indices with negative values seeds = [[0, 1], [0, 1]] targets = [[2, 3], [3, -1]] indices = _check_multivariate_indices((seeds, targets), n_signals) - assert all(np.ma.isMA(inds) for inds in indices) - assert all(inds.fill_value == mask_value for inds in indices) - assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), - np.array([[2, 3], [3, 4]]))) + assert np.ma.isMA(indices) + assert indices.fill_value == mask_value + assert np.all(indices == np.array([[[0, 1], [0, 1]], + [[2, 3], [3, 4]]])) # ragged indices seeds = [[0, 1], [0, 1]] targets = [[2, 3, 4], [4]] indices = _check_multivariate_indices((seeds, targets), n_signals) - assert all(np.ma.isMA(inds) for inds in indices) - assert all(inds.fill_value == mask_value for inds in indices) - assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), - np.array([[2, 3, 4], [4, -1, -1]]))) + assert np.ma.isMA(indices) + assert indices.fill_value == mask_value + assert np.all(indices == np.array([[[0, 1, -1], [0, 1, -1]], + [[2, 3, 4], [4, -1, -1]]])) # ragged indices with negative values seeds = [[0, 1], [0, 1]] targets = [[2, 3, 4], [-1]] indices = _check_multivariate_indices((seeds, targets), n_signals) - assert all(np.ma.isMA(inds) for inds in indices) - assert all(inds.fill_value == mask_value for inds in indices) - assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), - np.array([[2, 3, 4], [4, -1, -1]]))) + assert np.ma.isMA(indices) + assert indices.fill_value == mask_value + assert np.all(indices == np.array([[[0, 1, -1], [0, 1, -1]], + [[2, 3, 4], [4, -1, -1]]])) # test error catching with pytest.raises(ValueError, diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 2cab0007..1f278400 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -97,8 +97,8 @@ def _check_multivariate_indices(indices, n_chans): Returns ------- - indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans) - The indices as a masked array. + indices : array of array of int, shape of (2, n_cons, max_n_chans) + The padded indices as a masked array. Notes ----- @@ -176,17 +176,13 @@ def _check_multivariate_indices(indices, n_chans): indices[group_idx][con_idx][chan_idx] = chan % n_chans # pad indices to avoid ragged arrays - padded_indices = (np.full((n_cons, max_n_chans), invalid, dtype=np.int32), - np.full((n_cons, max_n_chans), invalid, dtype=np.int32)) - con_i = 0 - for seed, target in zip(indices[0], indices[1]): - padded_indices[0][con_i, :len(seed)] = seed - padded_indices[1][con_i, :len(target)] = target - con_i += 1 + padded_indices = np.full((2, n_cons, max_n_chans), invalid, dtype=np.int32) + for con_i, (seed, target) in enumerate(zip(indices[0], indices[1])): + padded_indices[0, con_i, :len(seed)] = seed + padded_indices[1, con_i, :len(target)] = target # mask invalid indices - masked_indices = (np.ma.masked_values(padded_indices[0], invalid), - np.ma.masked_values(padded_indices[1], invalid)) + masked_indices = np.ma.masked_values(padded_indices, invalid) return masked_indices From 3e8125287208f985b655ff9cd74868745e82b16a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 3 Nov 2023 13:57:22 +0100 Subject: [PATCH 40/40] fixed grammar mistake --- examples/handling_ragged_arrays.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index 8076b9e2..4c2d43b1 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -70,7 +70,7 @@ # about the multiple channel indices for each connection is stored in a nested # array. Importantly, these indices are ragged, as the first connection will be # computed between 2 seed and 3 target channels, and the second connection -# between 4 seed and 1 target channel. The connectivity functions will +# between 4 seed and 1 target channel(s). The connectivity functions will # recognise the indices as being ragged, and pad them to a 'full' array by # adding placeholder values which are masked accordingly. This makes the # indices easier to work with, and also compatible with the engine used to save