diff --git a/doc/api.rst b/doc/api.rst index 26ef14b6..300601b4 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -73,6 +73,7 @@ Post-processing on connectivity degree seed_target_indices + seed_target_multivariate_indices check_indices select_order diff --git a/doc/conf.py b/doc/conf.py index 2d4ae756..3dd9883f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -91,7 +91,8 @@ 'n_node_names', 'n_tapers', 'n_signals', 'n_step', 'n_freqs', 'epochs', 'freqs', 'times', 'arrays', 'lists', 'func', 'n_nodes', 'n_estimated_nodes', 'n_samples', 'n_channels', 'Renderer', - 'n_ytimes', 'n_ychannels', 'n_events' + 'n_ytimes', 'n_ychannels', 'n_events', 'n_cons', 'max_n_chans', + 'n_unique_seeds', 'n_unique_targets', 'variable' } numpydoc_xref_aliases = { # Python diff --git a/examples/granger_causality.py b/examples/granger_causality.py index f5d8316d..73c10a86 100644 --- a/examples/granger_causality.py +++ b/examples/granger_causality.py @@ -145,7 +145,7 @@ ############################################################################### # We will focus on connectivity between sensors over the parietal and occipital -# cortices, with 20 parietal sensors designated as group A, and 20 occipital +# cortices, with 20 parietal sensors designated as group A, and 22 occipital # sensors designated as group B. 
# %% @@ -157,17 +157,8 @@ signals_b = [idx for idx, ch_info in enumerate(epochs.info['chs']) if ch_info['ch_name'][2] == 'O'] -# XXX: Currently ragged indices are not supported, so we only consider a single -# list of indices with an equal number of seeds and targets -min_n_chs = min(len(signals_a), len(signals_b)) -signals_a = signals_a[:min_n_chs] -signals_b = signals_b[:min_n_chs] - -indices_ab = (np.array(signals_a), np.array(signals_b)) # A => B -indices_ba = (np.array(signals_b), np.array(signals_a)) # B => A - -signals_a_names = [epochs.info['ch_names'][idx] for idx in signals_a] -signals_b_names = [epochs.info['ch_names'][idx] for idx in signals_b] +indices_ab = (np.array([signals_a]), np.array([signals_b])) # A => B +indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A # compute Granger causality gc_ab = spectral_connectivity_epochs( @@ -181,8 +172,8 @@ ############################################################################### # Plotting the results, we see that there is a flow of information from our -# parietal sensors (group A) to our occipital sensors (group B) with noticeable -# peaks at around 8, 18, and 26 Hz. +# parietal sensors (group A) to our occipital sensors (group B) with a +# noticeable peak at ~8 Hz, and smaller peaks at 18 and 26 Hz. # %% @@ -208,8 +199,7 @@ # # Doing so, we see that the flow of information across the spectrum remains # dominant from parietal to occipital sensors (indicated by the positive-valued -# Granger scores). However, the pattern of connectivity is altered, such as -# around 10 and 12 Hz where peaks of net information flow are now present. +# Granger scores), with similar peaks around 10, 18, and 26 Hz. # %% @@ -289,8 +279,8 @@ # Plotting the TRGC results, reveals a very different picture compared to net # GC. For one, there is now a dominance of information flow ~6 Hz from # occipital to parietal sensors (indicated by the negative-valued Granger -# scores). 
Additionally, the peaks ~10 Hz are less dominant in the spectrum, -# with parietal to occipital information flow between 13-20 Hz being much more +# scores). Additionally, the peak ~10 Hz is less dominant in the spectrum, with +# parietal to occipital information flow between 13-20 Hz being much more # prominent. The stark difference between net GC and TRGC results indicates # that the net GC spectrum was contaminated by spurious connectivity resulting # from source mixing or correlated noise in the recordings. Altogether, the use @@ -353,21 +343,21 @@ # an automatic rank computation is performed and an appropriate degree of # dimensionality reduction will be enforced. The rank of the data is determined # by computing the singular values of the data and finding those within a -# factor of :math:`1e^{-10}` relative to the largest singular value. +# factor of :math:`1e^{-6}` relative to the largest singular value. # -# In some circumstances, this threshold may be too lenient, in which case you -# should inspect the singular values of your data to identify an appropriate -# degree of dimensionality reduction to perform, which you can then specify -# manually using the ``rank`` argument. The code below shows one possible -# approach for finding an appropriate rank of close-to-singular data with a -# more conservative threshold of :math:`1e^{-5}`. +# Whilst unlikely, there may be scenarios in which this threshold may be too +# lenient. In these cases, you should inspect the singular values of your data +# to identify an appropriate degree of dimensionality reduction to perform, +# which you can then specify manually using the ``rank`` argument. The code +# below shows one possible approach for finding an appropriate rank of +# close-to-singular data with a more conservative threshold. 
# %% # gets the singular values of the data s = np.linalg.svd(raw.get_data(), compute_uv=False) -# finds how many singular values are "close" to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria +# finds how many singular values are 'close' to the largest singular value +rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria ############################################################################### # Nonethless, even in situations where you specify an appropriate rank, it is @@ -387,7 +377,7 @@ try: spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None, - gc_n_lags=20) # A => B + gc_n_lags=20, verbose=False) # A => B print('Success!') except RuntimeError as error: print('\nCaught the following error:\n' + repr(error)) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py new file mode 100644 index 00000000..4c2d43b1 --- /dev/null +++ b/examples/handling_ragged_arrays.py @@ -0,0 +1,151 @@ +""" +========================================================= +Working with ragged indices for multivariate connectivity +========================================================= + +This example demonstrates how multivariate connectivity involving different +numbers of seeds and targets can be handled in MNE-Connectivity. +""" + +# Author: Thomas S. Binns +# License: BSD (3-clause) + +# %% + +import numpy as np + +from mne_connectivity import spectral_connectivity_epochs + +############################################################################### +# Background +# ---------- +# +# With multivariate connectivity, interactions between multiple signals can be +# considered together, and the number of signals designated as seeds and +# targets does not have to be equal within or across connections. 
Issues can +# arise from this when storing information associated with connectivity in +# arrays, as the number of entries within each dimension can vary within and +# across connections depending on the number of seeds and targets. Such arrays +# are 'ragged', and support for ragged arrays is limited in NumPy to the +# ``object`` datatype. Not only is working with ragged arrays cumbersome, +# but saving arrays with ``dtype='object'`` is not supported by the h5netcdf +# engine used to save connectivity objects. The workaround used in +# MNE-Connectivity is to pad ragged arrays with some known values according to +# the largest number of entries in each dimension, such that there is an equal +# amount of information across and within connections for each dimension of the +# arrays. +# +# As an example, consider we have 5 channels and want to compute 2 connections: +# the first between channels in indices 0 and 1 with those in indices 2, 3, +# and 4; and the second between channels 0, 1, 2, and 3 with channel 4. The +# seed and target indices can be written as such:: +# +# seeds = [[0, 1 ], [0, 1, 2, 3]] +# targets = [[2, 3, 4], [4 ]] +# +# The ``indices`` parameter passed to +# :func:`~mne_connectivity.spectral_connectivity_epochs` and +# :func:`~mne_connectivity.spectral_connectivity_time` must be a tuple of +# array-likes, meaning +# that the indices can be passed as a tuple of: lists; tuples; or NumPy arrays. +# Examples of how ``indices`` can be formed are shown below:: +# +# # tuple of lists +# ragged_indices = ([[0, 1 ], [0, 1, 2, 3]], +# [[2, 3, 4], [4 ]]) +# +# # tuple of tuples +# ragged_indices = (((0, 1 ), (0, 1, 2, 3)), +# ((2, 3, 4), (4 ))) +# +# # tuple of arrays +# ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'), +# np.array([[2, 3, 4], [4 ]], dtype='object')) +# +# **N.B.
Note that since NumPy v1.19.0, dtype='object' must be specified when +# forming ragged arrays.** +# +# Just as for bivariate connectivity, the length of ``indices[0]`` and +# ``indices[1]`` is equal (i.e. the number of connections), however information +# about the multiple channel indices for each connection is stored in a nested +# array. Importantly, these indices are ragged, as the first connection will be +# computed between 2 seed and 3 target channels, and the second connection +# between 4 seed and 1 target channel(s). The connectivity functions will +# recognise the indices as being ragged, and pad them to a 'full' array by +# adding placeholder values which are masked accordingly. This makes the +# indices easier to work with, and also compatible with the engine used to save +# connectivity objects. For example, the above indices would become:: +# +# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]), +# np.array([[2, 3, 4, --], [4, --, --, --]])) +# +# where ``--`` are masked entries. These indices are what is stored in the +# returned connectivity objects. +# +# For the connectivity results themselves, the methods available in +# MNE-Connectivity combine information across the different channels into a +# single (time-)frequency-resolved connectivity spectrum, regardless of the +# number of seed and target channels, so ragged arrays are not a concern here. +# However, the maximised imaginary part of coherency (MIC) method also returns +# spatial patterns of connectivity, which show the contribution of each channel +# to the dimensionality-reduced connectivity estimate (explained in more detail +# in :doc:`mic_mim`). Because these patterns are returned for each channel, +# their shape can vary depending on the number of seeds and targets in each +# connection, making them ragged. To avoid this, the patterns are padded along +# the channel axis with the known and invalid entry ``np.nan``, in line with +# that applied to ``indices``. 
Extracting only the valid spatial patterns from +# the connectivity object is trivial, as shown below: + +# %% + +# create random data +data = np.random.randn(10, 5, 200) # epochs x channels x times +sfreq = 50 +ragged_indices = ([[0, 1], [0, 1, 2, 3]], # seeds + [[2, 3, 4], [4]]) # targets + +# compute connectivity +con = spectral_connectivity_epochs( + data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30, + verbose=False) +patterns = np.array(con.attrs['patterns']) +padded_indices = con.indices +n_freqs = con.get_data().shape[-1] +n_cons = len(ragged_indices[0]) +max_n_chans = max( + len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])) + +# show that the padded indices entries are masked +assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels +assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels +assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels +assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels + +# patterns have shape [seeds/targets x cons x max channels x freqs (x times)] +assert patterns.shape == (2, n_cons, max_n_chans, n_freqs) + +# show that the padded patterns entries are all np.nan +assert np.all(np.isnan(patterns[0, 0, 2:])) # 2 padded channels +assert np.all(np.isnan(patterns[1, 0, 3:])) # 1 padded channels +assert not np.any(np.isnan(patterns[0, 1])) # 0 padded channels +assert np.all(np.isnan(patterns[1, 1, 1:])) # 3 padded channels + +# extract patterns for first connection using the ragged indices +seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] +target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] + +# extract patterns for second connection using the padded, masked indices +seed_patterns_con2 = ( + patterns[0, 1, :padded_indices[0][1].count()]) +target_patterns_con2 = ( + patterns[1, 1, :padded_indices[1][1].count()]) + +# show that shapes of patterns are correct +assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) 
+assert target_patterns_con1.shape == (3, n_freqs) # channels (2, 3, 4) +assert seed_patterns_con2.shape == (4, n_freqs) # channels (0, 1, 2, 3) +assert target_patterns_con2.shape == (1, n_freqs) # channels (4) + +print('Assertions completed successfully!') + +# %% diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 179ea620..86044969 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -70,7 +70,7 @@ ############################################################################### # We will focus on connectivity between sensors over the left and right # hemispheres, with 75 sensors in the left hemisphere designated as seeds, and -# 75 sensors in the right hemisphere designated as targets. +# 76 sensors in the right hemisphere designated as targets. # %% @@ -81,13 +81,7 @@ targets = [idx for idx, ch_info in enumerate(epochs.info['chs']) if ch_info['loc'][0] > 0] -# XXX: Currently ragged indices are not supported, so we only consider a single -# list of indices with an equal number of seeds and targets -min_n_chs = min(len(seeds), len(targets)) -seeds = seeds[:min_n_chs] -targets = targets[:min_n_chs] - -multivar_indices = (np.array(seeds), np.array(targets)) +multivar_indices = (np.array([seeds]), np.array([targets])) seed_names = [epochs.info['ch_names'][idx] for idx in seeds] target_names = [epochs.info['ch_names'][idx] for idx in targets] @@ -171,12 +165,11 @@ # # Here, we average across the patterns in the 13-18 Hz range. Plotting the # patterns shows that the greatest connectivity between the left and right -# hemispheres occurs at the posteromedial regions, based on the regions with -# the largest absolute values. Using the signs of the values, we can infer the -# existence of a dipole source in the central regions of the left hemisphere -# which may account for the connectivity contributions seen for the left -# posteromedial and frontolateral areas (represented on the plot as a green -# line). 
+# hemispheres occurs at the left and right posterior and left central regions, +# based on the areas with the largest absolute values. Using the signs of the +# values, we can infer the existence of a dipole source between the central and +# posterior regions of the left hemisphere accounting for the connectivity +# contributions (represented on the plot as a green line). # %% @@ -185,9 +178,9 @@ fband_idx = [mic.freqs.index(freq) for freq in fband] # patterns have shape [seeds/targets x cons x channels x freqs (x times)] -patterns = np.array(mic.attrs["patterns"]) -seed_pattern = patterns[0] -target_pattern = patterns[1] +patterns = np.array(mic.attrs['patterns']) +seed_pattern = patterns[0, :, :len(seeds)] +target_pattern = patterns[1, :, :len(targets)] # average across frequencies seed_pattern = np.mean(seed_pattern[0, :, fband_idx[0]:fband_idx[1] + 1], axis=1) @@ -217,7 +210,7 @@ # plot the left hemisphere dipole example axes[0].plot( - [-0.1, -0.05], [-0.075, -0.03], color='lime', linewidth=2, + [-0.01, -0.07], [-0.07, -0.03], color='lime', linewidth=2, path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()]) plt.show() @@ -268,7 +261,7 @@ axis.set_ylabel('Absolute connectivity (A.U.)') fig.suptitle('Multivariate interaction measure') -n_channels = len(np.unique([*multivar_indices[0], *multivar_indices[1]])) +n_channels = len(seeds) + len(targets) normalised_mim = mim.get_data()[0] / n_channels print(f'Normalised MIM has a maximum value of {normalised_mim.max():.2f}') @@ -296,7 +289,7 @@ # %% -indices = (np.array([*seeds, *targets]), np.array([*seeds, *targets])) +indices = (np.array([[*seeds, *targets]]), np.array([[*seeds, *targets]])) gim = spectral_connectivity_epochs( epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None, verbose=False) @@ -307,7 +300,7 @@ axis.set_ylabel('Connectivity (A.U.)') fig.suptitle('Global interaction measure') -n_channels = len(np.unique([*indices[0], *indices[1]])) +n_channels = len(seeds) + len(targets) 
normalised_gim = gim.get_data()[0] / n_channels print(f'Normalised GIM has a maximum value of {normalised_gim.max():.2f}') @@ -369,9 +362,9 @@ # no. channels equal with and without projecting to rank subspace for patterns assert (patterns[0, 0].shape[0] == - np.array(mic_red.attrs["patterns"])[0, 0].shape[0]) + np.array(mic_red.attrs['patterns'])[0, 0].shape[0]) assert (patterns[1, 0].shape[0] == - np.array(mic_red.attrs["patterns"])[1, 0].shape[0]) + np.array(mic_red.attrs['patterns'])[1, 0].shape[0]) ############################################################################### @@ -379,21 +372,21 @@ # an automatic rank computation is performed and an appropriate degree of # dimensionality reduction will be enforced. The rank of the data is determined # by computing the singular values of the data and finding those within a -# factor of :math:`1e^{-10}` relative to the largest singular value. +# factor of :math:`1e^{-6}` relative to the largest singular value. # -# In some circumstances, this threshold may be too lenient, in which case you -# should inspect the singular values of your data to identify an appropriate -# degree of dimensionality reduction to perform, which you can then specify -# manually using the ``rank`` argument. The code below shows one possible -# approach for finding an appropriate rank of close-to-singular data with a -# more conservative threshold of :math:`1e^{-5}`. +# Whilst unlikely, there may be scenarios in which this threshold may be too +# lenient. In these cases, you should inspect the singular values of your data +# to identify an appropriate degree of dimensionality reduction to perform, +# which you can then specify manually using the ``rank`` argument. The code +# below shows one possible approach for finding an appropriate rank of +# close-to-singular data with a more conservative threshold. 
# %% # gets the singular values of the data s = np.linalg.svd(raw.get_data(), compute_uv=False) -# finds how many singular values are "close" to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria +# finds how many singular values are 'close' to the largest singular value +rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria ############################################################################### diff --git a/ignore_words.txt b/ignore_words.txt index c3614a01..304ce95a 100644 --- a/ignore_words.txt +++ b/ignore_words.txt @@ -1,3 +1,4 @@ nd adn +ba manuel diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index 57aeff7f..43fe0793 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -17,4 +17,5 @@ from .io import read_connectivity from .spectral import spectral_connectivity_time, spectral_connectivity_epochs from .vector_ar import vector_auto_regression, select_order -from .utils import check_indices, degree, seed_target_indices +from .utils import (check_indices, degree, seed_target_indices, + seed_target_multivariate_indices) diff --git a/mne_connectivity/base.py b/mne_connectivity/base.py index 88951529..8123087c 100644 --- a/mne_connectivity/base.py +++ b/mne_connectivity/base.py @@ -483,7 +483,14 @@ def _prepare_xarray(self, data, names, indices, n_nodes, method, # set method, indices and n_nodes if isinstance(indices, tuple): - new_indices = (list(indices[0]), list(indices[1])) + if all(isinstance(inds, np.ndarray) for inds in indices): + # leave multivariate indices as arrays for easier indexing + if all(inds.ndim > 1 for inds in indices): + new_indices = (indices[0], indices[1]) + else: + new_indices = (list(indices[0]), list(indices[1])) + else: + new_indices = (list(indices[0]), list(indices[1])) indices = new_indices kwargs['method'] = method kwargs['indices'] = indices diff --git a/mne_connectivity/io.py 
b/mne_connectivity/io.py index 2c99eebc..e8d9b916 100644 --- a/mne_connectivity/io.py +++ b/mne_connectivity/io.py @@ -53,6 +53,12 @@ def _xarray_to_conn(array, cls_func): event_id = dict(zip(event_id_keys, event_id_vals)) array.attrs['event_id'] = event_id + # convert indices numpy arrays to a tuple of masked arrays + # (only multivariate connectivity indices saved as arrays) + if isinstance(array.attrs['indices'], np.ndarray): + array.attrs['indices'] = tuple( + np.ma.masked_values(array.attrs['indices'], -1)) + # create the connectivity class conn = cls_func( data=data, names=names, metadata=metadata, **array.attrs diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index eb766f06..f4899e87 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -24,7 +24,7 @@ ProgressBar, _arange_div, _check_option, _time_mask, logger, warn, verbose) from ..base import (SpectralConnectivity, SpectroTemporalConnectivity) -from ..utils import fill_doc, check_indices +from ..utils import fill_doc, check_indices, _check_multivariate_indices def _compute_freqs(n_times, sfreq, cwt_freqs, mode): @@ -92,40 +92,45 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, times = times_in[tmin_idx:tmax_idx] n_times = len(times) + if any(this_method in _multivariate_methods for this_method in method): + multivariate_con = True + else: + multivariate_con = False + if indices is None: - if any(this_method in _multivariate_methods for this_method in method): + if multivariate_con: if any(this_method in _gc_methods for this_method in method): raise ValueError( 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') else: logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int), - np.arange(n_signals, dtype=int)) + # indices expected to be a masked array, even if not ragged + indices_use = 
(np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) + indices_use = np.ma.masked_array(indices_use, + mask=False, fill_value=-1) else: logger.info('only using indices for lower-triangular matrix') # only compute r for lower-triangular region indices_use = np.tril_indices(n_signals, -1) else: - if any(this_method in _gc_methods for this_method in method): - if set(indices[0]).intersection(indices[1]): - raise ValueError( - 'seed and target indices must not intersect when computing' - 'Granger causality') - indices_use = check_indices(indices) + if multivariate_con: + # pad ragged indices and mask the invalid entries + indices_use = _check_multivariate_indices(indices, n_signals) + if any(this_method in _gc_methods for this_method in method): + for seed, target in zip(indices_use[0], indices_use[1]): + intersection = np.intersect1d(seed.compressed(), + target.compressed()) + if intersection.size > 0: + raise ValueError( + 'seed and target indices must not intersect when ' + 'computing Granger causality') + else: + indices_use = check_indices(indices) # number of connectivities to compute - if any(this_method in _multivariate_methods for this_method in method): - if ( - len(np.unique(indices_use[0])) != len(indices_use[0]) or - len(np.unique(indices_use[1])) != len(indices_use[1]) - ): - raise ValueError( - 'seed and target indices cannot contain repeated channels for ' - 'multivariate connectivity') - n_cons = 1 # UNTIL RAGGED ARRAYS SUPPORTED - else: - n_cons = len(indices_use[0]) + n_cons = len(indices_use[0]) logger.info(' computing connectivity for %d connections' % n_cons) @@ -189,6 +194,46 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, n_signals, indices_use, warn_times) +def _check_rank_input(rank, data, indices): + """Check the rank argument is appropriate and compute rank if missing.""" + sv_tol = 1e-6 # tolerance for non-zero singular val (rel. 
to largest) + if rank is None: + rank = np.zeros((2, len(indices[0])), dtype=int) + + if isinstance(data, BaseEpochs): + data_arr = data.get_data() + else: + data_arr = data + + for group_i in range(2): # seeds and targets + for con_i, con_idcs in enumerate(indices[group_i]): + s = np.linalg.svd(data_arr[:, con_idcs.compressed()], + compute_uv=False) + rank[group_i][con_i] = np.min( + [np.count_nonzero(epoch >= epoch[0] * sv_tol) + for epoch in s]) + + logger.info('Estimated data ranks:') + con_i = 1 + for seed_rank, target_rank in zip(rank[0], rank[1]): + logger.info(' connection %i - seeds (%i); targets (%i)' + % (con_i, seed_rank, target_rank, )) + con_i += 1 + rank = tuple((np.array(rank[0]), np.array(rank[1]))) + + else: + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], rank[0], rank[1]): + if not (0 < seed_rank <= len(seed_idcs) and + 0 < target_rank <= len(target_idcs)): + raise ValueError( + 'ranks for seeds and targets must be > 0 and <= the ' + 'number of channels in the seeds and targets, ' + 'respectively, for each connection') + + return rank + + def _assemble_spectral_params(mode, n_times, mt_adaptive, mt_bandwidth, sfreq, mt_low_bias, cwt_n_cycles, cwt_freqs, freqs, freq_mask): @@ -422,22 +467,23 @@ def compute_con(self, indices, ranks, n_epochs=1): if self.name == 'MIC': self.patterns = np.full( - (2, self.n_cons, len(indices[0]), self.n_freqs, n_times), + (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), np.nan) con_i = 0 for seed_idcs, target_idcs, seed_rank, target_rank in zip( - [indices[0]], [indices[1]], ranks[0], ranks[1]): + indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) - n_seeds = len(seed_idcs) + seed_idcs = seed_idcs.compressed() + target_idcs = target_idcs.compressed() con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] # Eqs. 
32 & 33 C_bar, U_bar_aa, U_bar_bb = self._csd_svd( - C, n_seeds, seed_rank, target_rank) + C, seed_idcs, seed_rank, target_rank) # Eqs. 3 & 4 E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) @@ -452,10 +498,11 @@ def compute_con(self, indices, ranks, n_epochs=1): self.reshape_results() - def _csd_svd(self, csd, n_seeds, seed_rank, target_rank): + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): """Dimensionality reduction of CSD with SVD.""" n_times = csd.shape[0] - n_targets = csd.shape[2] - n_seeds + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds C_aa = csd[..., :n_seeds, :n_seeds] C_ab = csd[..., :n_seeds, n_seeds:] @@ -505,8 +552,9 @@ def _compute_e(self, csd, n_seeds): for block_i in ProgressBar( range(self.n_steps), mesg="frequency blocks"): freqs = self._get_block_indices(block_i, self.n_freqs) - parallel(parallel_compute_t( + T[:, freqs] = np.array(parallel(parallel_compute_t( C_r[:, f], T[:, f], n_seeds) for f in freqs) + ).transpose(1, 0, 2, 3) if not np.isreal(T).all() or not np.isfinite(T).all(): raise RuntimeError( @@ -526,6 +574,7 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, U_bar_bb, con_i): """Compute MIC and the associated spatial patterns.""" n_seeds = len(seed_idcs) + n_targets = len(target_idcs) times = np.arange(n_times) freqs = np.arange(self.n_freqs) @@ -534,12 +583,15 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, np.matmul(E, E.transpose(0, 1, 3, 2))) w_targets, V_targets = np.linalg.eigh( np.matmul(E.transpose(0, 1, 3, 2), E)) - if np.all(seed_idcs == target_idcs): + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + ): # strange edge-case where the eigenvectors returned should be a set # of identity matrices with one rotated by 90 degrees, but are # instead identical (i.e. are not rotated versions of one another). 
# This leads to the case where the spatial filters are incorrectly - # applied, resulting in connectivity estimates of e.g. ~0 when they + # applied, resulting in connectivity estimates of ~0 when they # should be perfectly correlated ~1. Accordingly, we manually # create a set of rotated identity matrices to use as the filters. create_filter = False @@ -564,12 +616,12 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] # Eq. 46 (seed spatial patterns) - self.patterns[0, con_i] = (np.matmul( + self.patterns[0, con_i, :n_seeds] = (np.matmul( np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T # Eq. 47 (target spatial patterns) - self.patterns[1, con_i] = (np.matmul( + self.patterns[1, con_i, :n_targets] = (np.matmul( np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T @@ -586,7 +638,10 @@ def _compute_mim(self, E, seed_idcs, target_idcs, con_i): E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T # Eq. 
15 - if all(np.unique(seed_idcs) == np.unique(target_idcs)): + if ( + len(seed_idcs) == len(target_idcs) and + np.all(np.sort(seed_idcs) == np.sort(target_idcs)) + ): self.con_scores[con_i] *= 0.5 def reshape_results(self): @@ -598,7 +653,7 @@ def reshape_results(self): def _mic_mim_compute_t(C, T, n_seeds): - """Compute T in place for a single frequency (used for MIC and MIM).""" + """Compute T for a single frequency (used for MIC and MIM).""" for time_i in range(C.shape[0]): T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( C[time_i, :n_seeds, :n_seeds], -0.5 @@ -607,6 +662,8 @@ def _mic_mim_compute_t(C, T, n_seeds): C[time_i, n_seeds:, n_seeds:], -0.5 ) + return T + class _MICEst(_MultivariateCohEstBase): """Multivariate imaginary part of coherency (MIC) estimator.""" @@ -889,17 +946,16 @@ def compute_con(self, indices, ranks, n_epochs=1): con_i = 0 for seed_idcs, target_idcs, seed_rank, target_rank in zip( - [indices[0]], [indices[1]], ranks[0], ranks[1]): + indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) + seed_idcs = seed_idcs.compressed() + target_idcs = target_idcs.compressed() con_idcs = [*seed_idcs, *target_idcs] - C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - con_seeds = np.arange(len(seed_idcs)) - con_targets = np.arange(len(target_idcs)) + len(seed_idcs) + C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - C_bar = self._csd_svd( - C, con_seeds, con_targets, seed_rank, target_rank) + C_bar = self._csd_svd(C, seed_idcs, seed_rank, target_rank) n_signals = seed_rank + target_rank con_seeds = np.arange(seed_rank) con_targets = np.arange(target_rank) + seed_rank @@ -921,13 +977,13 @@ def compute_con(self, indices, ranks, n_epochs=1): self.reshape_results() - def _csd_svd(self, csd, seeds, targets, seed_rank, target_rank): + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): """Dimensionality reduction of CSD with SVD on the covariance.""" # sum over times and epochs to get cov. 
from CSD cov = csd.sum(axis=(0, 1)) - n_seeds = len(seeds) - n_targets = len(targets) + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds cov_aa = cov[:n_seeds, :n_seeds] cov_bb = cov[n_seeds:, n_seeds:] @@ -1202,7 +1258,7 @@ def _gc_compute_H(A, C, K, z_k, I_n, I_m): See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: 10.1103/PhysRevE.91.040101, Eq. 4. """ - from scipy import linalg # is this necessary??? + from scipy import linalg # XXX: is this necessary??? H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) for t in range(A.shape[0]): H[t] = I_n + np.matmul( @@ -1231,16 +1287,15 @@ class _GCTREst(_GCEstBase): def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, method, mode, window_fun, eigvals, wavelets, - freq_mask, mt_adaptive, idx_map, block_size, - psd, accumulate_psd, con_method_types, - con_methods, n_signals, n_signals_use, - n_times, gc_n_lags, accumulate_inplace=True): + freq_mask, mt_adaptive, idx_map, n_cons, + block_size, psd, accumulate_psd, + con_method_types, con_methods, n_signals, + n_signals_use, n_times, gc_n_lags, + accumulate_inplace=True): """Estimate connectivity for one epoch (see spectral_connectivity).""" if any(this_method in _multivariate_methods for this_method in method): - n_cons = 1 # UNTIL RAGGED ARRAYS SUPPORTED n_con_signals = n_signals_use ** 2 else: - n_cons = len(idx_map[0]) n_con_signals = n_cons if wavelets is not None: @@ -1513,10 +1568,13 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, 'mim', 'gc', 'gc_tr]``) cannot be called with the other methods. indices : tuple of array | None Two arrays with indices of connections for which to compute - connectivity. If a multivariate method is called, the indices are for a - single connection between all seeds and all targets. If None, all - connections are computed, unless a Granger causality method is called, - in which case an error is raised. + connectivity. 
If a bivariate method is called, each array for the seeds + and targets should contain the channel indices for each bivariate + connection. If a multivariate method is called, each array for the + seeds and targets should consist of nested arrays containing + the channel indices for each multivariate connection. If ``None``, + connections between all channels are computed, unless a Granger + causality method is called, in which case an error is raised. sfreq : float The sampling frequency. Required if data is not :class:`Epochs `. @@ -1582,14 +1640,15 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, con : array | list of array Computed connectivity measure(s). Either an instance of ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of each connectivity dataset is either - (n_signals ** 2, n_freqs) mode: 'multitaper' or 'fourier' - (n_signals ** 2, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is None, or - (n_con, n_freqs) mode: 'multitaper' or 'fourier' - (n_con, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is specified and "n_con = len(indices[0])". If a - multivariate method is called "n_con = 1" even if "indices" is None. + The shape of the connectivity result will be: + + - ``(n_cons, n_freqs)`` for multitaper or fourier modes + - ``(n_cons, n_freqs, n_times)`` for cwt_morlet mode + - ``n_cons = n_signals ** 2`` for bivariate methods with + ``indices=None`` + - ``n_cons = 1`` for multivariate methods with ``indices=None`` + - ``n_cons = len(indices[0])`` for bivariate and multivariate methods + when indices is supplied. See Also -------- @@ -1635,13 +1694,19 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, are in the same order as defined indices. For multivariate methods, this is handled differently. If "indices" is - None, connectivity between all signals will attempt to be computed (this is - not possible if a Granger causality method is called). 
If "indices" is - specified, the seeds and targets are treated as a single connection. For - example, to compute the connectivity between signals 0, 1, 2 and 3, 4, 5, - one would use the same approach as above, however the signals would all be - considered for a single connection and the connectivity scores would have - the shape (1, n_freqs). + None, connectivity between all signals will be computed and a single + connectivity spectrum will be returned (this is not possible if a Granger + causality method is called). If "indices" is specified, seed and target + indices for each connection should be specified as nested array-likes. For + example, to compute the connectivity between signals (0, 1) -> (2, 3) and + (0, 1) -> (4, 5), indices should be specified as:: + + indices = (np.array([[0, 1], [0, 1]]), # seeds + np.array([[2, 3], [4, 5]])) # targets + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. 
**Supported Connectivity Measures** @@ -1834,11 +1899,16 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # check rank input and compute data ranks if necessary if multivariate_con: - rank = _check_rank_input(rank, data, sfreq, indices_use) + rank = _check_rank_input(rank, data, indices_use) else: rank = None gc_n_lags = None + # make sure padded indices are stored in the connectivity object + if multivariate_con and indices is not None: + # create a copy so that `indices_use` can be modified + indices = (indices_use[0].copy(), indices_use[1].copy()) + # get the window function, wavelets, etc for different modes (spectral_params, mt_adaptive, n_times_spectrum, n_tapers) = _assemble_spectral_params( @@ -1848,16 +1918,25 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) # unique signals for which we actually need to compute PSD etc. - sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) + if multivariate_con: + sig_idx = np.unique(indices_use.compressed()) + remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} + remapped_inds = indices_use.copy() + for idx in sig_idx: + remapped_inds[indices_use == idx] = remapping[idx] + remapped_sig = np.unique(remapped_inds.compressed()) + else: + sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) n_signals_use = len(sig_idx) # map indices to unique indices - idx_map = [np.searchsorted(sig_idx, ind) for ind in indices_use] if multivariate_con: - indices_use = idx_map - idx_map = np.unique([*idx_map[0], *idx_map[1]]) - idx_map = [np.sort(np.repeat(idx_map, len(sig_idx))), - np.tile(idx_map, len(sig_idx))] + indices_use = remapped_inds # use remapped seeds & targets + idx_map = [np.sort(np.repeat(remapped_sig, len(sig_idx))), + np.tile(remapped_sig, len(sig_idx))] + else: + idx_map = [ + np.searchsorted(sig_idx, ind) for ind in indices_use] # allocate space to accumulate PSD if 
accumulate_psd: @@ -1894,7 +1973,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, call_params = dict( sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, - block_size=block_size, + n_cons=n_cons, block_size=block_size, psd=psd, accumulate_psd=accumulate_psd, mt_adaptive=mt_adaptive, con_method_types=con_method_types, @@ -1978,8 +2057,8 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, this_con = this_con_bands if this_patterns is not None: - patterns_shape = ((2, n_cons, len(indices[0]), n_bands) + - this_patterns.shape[4:]) + patterns_shape = list(this_patterns.shape) + patterns_shape[3] = n_bands this_patterns_bands = np.empty(patterns_shape, dtype=this_patterns.dtype) for band_idx in range(n_bands): @@ -2023,11 +2102,6 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # number of nodes in the original data n_nodes = n_signals - if multivariate_con: - # UNTIL RAGGED ARRAYS SUPPORTED - indices = tuple( - [[np.array(indices_use[0])], [np.array(indices_use[1])]]) - # create a list of connectivity containers conn_list = [] for _con, _patterns, _method in zip(con, patterns, method): @@ -2054,46 +2128,3 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, conn_list = conn_list[0] return conn_list - - -def _check_rank_input(rank, data, sfreq, indices): - """Check the rank argument is appropriate and compute rank if missing.""" - # UNTIL RAGGED ARRAYS SUPPORTED - indices = np.array([[indices[0]], [indices[1]]]) - - if rank is None: - - rank = np.zeros((2, len(indices[0])), dtype=int) - - if isinstance(data, BaseEpochs): - data_arr = data.get_data() - else: - data_arr = data - - for group_i in range(2): - for con_i, con_idcs in enumerate(indices[group_i]): - s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) - rank[group_i][con_i] = np.min( - 
[np.count_nonzero(epoch >= epoch[0] * 1e-10) - for epoch in s]) - - logger.info('Estimated data ranks:') - con_i = 1 - for seed_rank, target_rank in zip(rank[0], rank[1]): - logger.info(' connection %i - seeds (%i); targets (%i)' - % (con_i, seed_rank, target_rank, )) - con_i += 1 - - rank = tuple((np.array(rank[0]), np.array(rank[1]))) - - else: - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], rank[0], rank[1]): - if not (0 < seed_rank <= len(seed_idcs) and - 0 < target_rank <= len(target_idcs)): - raise ValueError( - 'ranks for seeds and targets must be > 0 and <= the ' - 'number of channels in the seeds and targets, ' - 'respectively, for each connection') - - return rank diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index c8f6ab23..a514382b 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -426,7 +426,9 @@ def test_spectral_connectivity_epochs_multivariate(method): trans_bandwidth = 2.0 # Hz delay = 10 # samples (non-zero delay needed for ImCoh and GC to be >> 0) - indices = tuple([np.arange(n_seeds), np.arange(n_seeds) + n_seeds]) + indices = (np.arange(n_seeds)[np.newaxis, :], + np.arange(n_seeds)[np.newaxis, :] + n_seeds) + n_targets = n_seeds # 15-25 Hz connectivity fstart, fend = 15.0, 25.0 @@ -497,8 +499,17 @@ def test_spectral_connectivity_epochs_multivariate(method): if method in ['mic', 'mim']: con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=None, sfreq=sfreq) - assert (np.array(con.indices).tolist() == - [[[0, 1, 2, 3]], [[0, 1, 2, 3]]]) + assert con.indices is None + assert con.n_nodes == n_signals + if method == 'mic': + assert np.array(con.attrs['patterns']).shape[2] == n_signals + + # check ragged indices padded correctly + ragged_indices = (np.array([[0]]), np.array([[1, 2]])) + con = spectral_connectivity_epochs( + data, method=method, mode=mode, 
indices=ragged_indices, sfreq=sfreq) + assert np.all(np.array(con.indices) == + np.array([np.array([[0, -1]]), np.array([[1, 2]])])) # check shape of MIC patterns if method == 'mic': @@ -510,12 +521,12 @@ def test_spectral_connectivity_epochs_multivariate(method): if mode == 'cwt_morlet': patterns_shape = ( - (len(indices[0]), len(con.freqs), len(con.times)), - (len(indices[1]), len(con.freqs), len(con.times))) + (n_seeds, len(con.freqs), len(con.times)), + (n_targets, len(con.freqs), len(con.times))) else: patterns_shape = ( - (len(indices[0]), len(con.freqs)), - (len(indices[1]), len(con.freqs))) + (n_seeds, len(con.freqs)), + (n_targets, len(con.freqs))) assert np.shape(con.attrs["patterns"][0][0]) == patterns_shape[0] assert np.shape(con.attrs["patterns"][1][0]) == patterns_shape[1] @@ -535,10 +546,22 @@ def test_spectral_connectivity_epochs_multivariate(method): con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=rank) - assert (np.shape(con.attrs["patterns"][0][0])[0] == - len(indices[0])) - assert (np.shape(con.attrs["patterns"][1][0])[0] == - len(indices[1])) + assert (np.shape(con.attrs["patterns"][0][0])[0] == n_seeds) + assert (np.shape(con.attrs["patterns"][1][0])[0] == n_targets) + + # check patterns padded correctly + ragged_indices = (np.array([[0]]), np.array([[1, 2]])) + con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=ragged_indices, + sfreq=sfreq) + patterns = np.array(con.attrs["patterns"]) + patterns_shape = ( + (n_seeds, len(con.freqs)), (n_targets, len(con.freqs))) + assert patterns[0, 0].shape == patterns_shape[0] + assert patterns[1, 0].shape == patterns_shape[1] + assert not np.any(np.isnan(patterns[0, 0, 0])) + assert np.all(np.isnan(patterns[0, 0, 1])) + assert not np.any(np.isnan(patterns[1, 0])) def test_multivariate_spectral_connectivity_epochs_regression(): @@ -561,7 +584,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): data = 
pd.read_pickle( os.path.join(fpath, 'data', 'example_multivariate_data.pkl')) sfreq = 100 - indices = tuple([[0, 1], [2, 3]]) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) methods = ['mic', 'mim', 'gc', 'gc_tr'] con = spectral_connectivity_epochs( data, method=methods, indices=indices, mode='multitaper', sfreq=sfreq, @@ -590,13 +613,21 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): n_times = 256 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) cwt_freqs = np.arange(10, 25 + 1) - # check bad indices with repeated channels + # check bad indices without nested array caught + with pytest.raises(TypeError, + match='multivariate indices must contain array-likes'): + non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + spectral_connectivity_epochs( + data, method=method, mode=mode, indices=non_nested_indices, + sfreq=sfreq, gc_n_lags=10) + + # check bad indices with repeated channels caught with pytest.raises(ValueError, - match='seed and target indices cannot contain'): - repeated_indices = tuple([[0, 1, 1], [2, 2, 3]]) + match='multivariate indices cannot contain repeated'): + repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) spectral_connectivity_epochs( data, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq, gc_n_lags=10) @@ -647,7 +678,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): sfreq=sfreq, rank=(np.array([2]), np.array([2])), cwt_freqs=cwt_freqs) - # only check these once for speed + # only check these once (e.g. 
only with multitaper) for speed if method == 'gc' and mode == 'multitaper': # check bad n_lags caught frange = (5, 10) @@ -665,7 +696,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): cwt_freqs=cwt_freqs) # check intersecting indices caught - bad_indices = (np.array([0, 1]), np.array([0, 2])) + bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) with pytest.raises(ValueError, match='seed and target indices must not intersect'): spectral_connectivity_epochs(data, method=method, mode=mode, @@ -698,7 +729,7 @@ def test_multivar_spectral_connectivity_parallel(method): n_times = 256 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) spectral_connectivity_epochs( data, method=method, mode="multitaper", indices=indices, sfreq=sfreq, @@ -708,6 +739,53 @@ def test_multivar_spectral_connectivity_parallel(method): indices=indices, sfreq=sfreq, gc_n_lags=10, n_jobs=2) +def test_multivar_spectral_connectivity_flipped_indices(): + """Test multivar. indices structure maintained by connectivity methods.""" + sfreq = 50. 
+ n_signals = 4 + n_epochs = 8 + n_times = 256 + rng = np.random.RandomState(0) + data = rng.randn(n_epochs, n_signals, n_times) + freqs = np.arange(10, 20) + + # if we're not careful, when finding the channels we need to compute the + # CSD for, we might accidentally reorder the connectivity indices + indices = (np.array([[0, 1]]), + np.array([[2, 3]])) + flipped_indices = (np.array([[2, 3]]), + np.array([[0, 1]])) + concat_indices = (np.array([[0, 1], [2, 3]]), + np.array([[2, 3], [0, 1]])) + + # we test on GC since this is a directed connectivity measure + method = 'gc' + + con_st = spectral_connectivity_epochs( # seed -> target + data, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10) + con_ts = spectral_connectivity_epochs( # target -> seed + data, method=method, indices=flipped_indices, sfreq=sfreq, + gc_n_lags=10) + con_st_ts = spectral_connectivity_epochs( # seed -> target; target -> seed + data, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10) + assert not np.all(con_st.get_data() == con_ts.get_data()) + assert np.all(con_st.get_data()[0] == con_st_ts.get_data()[0]) + assert np.all(con_ts.get_data()[0] == con_st_ts.get_data()[1]) + + con_st = spectral_connectivity_time( # seed -> target + data, freqs, method=method, indices=indices, sfreq=sfreq, + gc_n_lags=10) + con_ts = spectral_connectivity_time( # target -> seed + data, freqs, method=method, indices=flipped_indices, sfreq=sfreq, + gc_n_lags=10) + con_st_ts = spectral_connectivity_time( # seed -> target; target -> seed + data, freqs, method=method, indices=concat_indices, sfreq=sfreq, + gc_n_lags=10) + assert not np.all(con_st.get_data() == con_ts.get_data()) + assert np.all(con_st.get_data()[:, 0] == con_st_ts.get_data()[:, 0]) + assert np.all(con_ts.get_data()[:, 0] == con_st_ts.get_data()[:, 1]) + + @ pytest.mark.parametrize('kind', ('epochs', 'ndarray', 'stc', 'combo')) def test_epochs_tmin_tmax(kind): """Test spectral.spectral_connectivity_epochs with epochs and arrays.""" 
@@ -857,7 +935,7 @@ def test_spectral_connectivity_time_delayed(): trans_bandwidth = 2.0 # Hz delay = 5 # samples (non-zero delay needed for GC to be >> 0) - indices = tuple([np.arange(n_seeds), np.arange(n_seeds) + n_seeds]) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) # 20-30 Hz connectivity fstart, fend = 20.0, 30.0 @@ -1061,7 +1139,8 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): n_times = 500 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) + n_cons = len(indices[0]) freqs = np.arange(10, 25 + 1) con_shape = [1] @@ -1080,34 +1159,60 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): # check shape of MIC patterns are correct if method == 'mic': - patterns_shape = [len(indices[0])] - if faverage: - patterns_shape.append(1) - else: - patterns_shape.append(len(freqs)) - if not average: - patterns_shape = [n_epochs, *patterns_shape] - patterns_shape = [2, *patterns_shape] - assert np.array(con.attrs['patterns']).shape == tuple(patterns_shape) + for indices_type in ['full', 'ragged']: + if indices_type == 'full': + indices = (np.array([[0, 1]]), np.array([[2, 3]])) + else: + indices = (np.array([[0, 1]]), np.array([[2]])) + max_n_chans = 2 + patterns_shape = [n_cons, max_n_chans] + if faverage: + patterns_shape.append(1) + else: + patterns_shape.append(len(freqs)) + if not average: + patterns_shape = [n_epochs, *patterns_shape] + patterns_shape = [2, *patterns_shape] + + con = spectral_connectivity_time( + data, freqs, indices=indices, method=method, sfreq=sfreq, + faverage=faverage, average=average, gc_n_lags=10) + patterns = np.array(con.attrs['patterns']) + # 2 (x epochs) x cons x channels x freqs|fbands + assert (patterns.shape == tuple(patterns_shape)) + if indices_type == 'ragged': + assert not np.any(np.isnan(patterns[0, ..., :, :])) + assert not 
np.any(np.isnan(patterns[0, ..., 0, :])) + assert np.all(np.isnan(patterns[1, ..., 1, :])) # padded entry + assert np.all(np.array(con.indices) == np.array( + (np.array([[0, 1]]), np.array([[2, -1]])))) -@pytest.mark.parametrize( - 'method', ['mic', 'mim', 'gc', 'gc_tr']) -def test_multivar_spectral_connectivity_time_error_catch(method): + +@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize('mode', ['multitaper', 'cwt_morlet']) +def test_multivar_spectral_connectivity_time_error_catch(method, mode): """Test error catching for time-resolved multivar. connectivity methods.""" sfreq = 50. n_signals = 4 # Do not change! n_epochs = 8 n_times = 256 data = np.random.rand(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) freqs = np.arange(10, 25 + 1) - # check bad indices with repeated channels + # check bad indices without nested array caught + with pytest.raises(TypeError, + match='multivariate indices must contain array-likes'): + non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + spectral_connectivity_time(data, freqs, method=method, mode=mode, + indices=non_nested_indices, sfreq=sfreq) + + # check bad indices with repeated channels caught with pytest.raises(ValueError, - match='seed and target indices cannot contain'): - repeated_indices = tuple([[0, 1, 1], [2, 2, 3]]) - spectral_connectivity_time(data, freqs, method=method, + match='multivariate indices cannot contain repeated'): + repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq) # check mixed methods caught @@ -1115,7 +1220,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method): match='bivariate and multivariate connectivity'): mixed_methods = [method, 'coh'] spectral_connectivity_time(data, freqs, method=mixed_methods, - indices=indices, sfreq=sfreq) 
+ mode=mode, indices=indices, sfreq=sfreq) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) @@ -1123,38 +1228,40 @@ def test_multivar_spectral_connectivity_time_error_catch(method): match='ranks for seeds and targets must be'): spectral_connectivity_time( data, freqs, method=method, indices=indices, sfreq=sfreq, - rank=too_low_rank) + mode=mode, rank=too_low_rank) too_high_rank = (np.array([3]), np.array([3])) with pytest.raises(ValueError, match='ranks for seeds and targets must be'): spectral_connectivity_time( data, freqs, method=method, indices=indices, sfreq=sfreq, - rank=too_high_rank) + mode=mode, rank=too_high_rank) # check all-to-all conn. computed for MIC/MIM when no indices given if method in ['mic', 'mim']: - con = spectral_connectivity_epochs( - data, freqs, method=method, indices=None, sfreq=sfreq) - assert (np.array(con.indices).tolist() == - [[[0, 1, 2, 3]], [[0, 1, 2, 3]]]) + con = spectral_connectivity_time( + data, freqs, method=method, indices=None, sfreq=sfreq, mode=mode) + assert con.indices is None + assert con.n_nodes == n_signals + if method == 'mic': + assert np.array(con.attrs['patterns']).shape[3] == n_signals if method in ['gc', 'gc_tr']: # check no indices caught with pytest.raises(ValueError, match='indices must be specified'): - spectral_connectivity_time(data, freqs, method=method, + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=None, sfreq=sfreq) # check intersecting indices caught - bad_indices = (np.array([0, 1]), np.array([0, 2])) + bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) with pytest.raises(ValueError, match='seed and target indices must not intersect'): - spectral_connectivity_time(data, freqs, method=method, + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=bad_indices, sfreq=sfreq) # check bad fmin/fmax caught with pytest.raises(ValueError, match='computing Granger causality on multiple'): - spectral_connectivity_time(data, freqs, 
method=method, + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=(5., 15.), fmax=(15., 30.)) @@ -1162,7 +1269,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method): def test_save(tmp_path): """Test saving results of spectral connectivity.""" rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 10, 2, 2000, 1000., 20. + n_epochs, n_chs, n_times, sfreq, f = 10, 3, 2000, 1000., 20. data = rng.randn(n_epochs, n_chs, n_times) sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) data[:, :, 500:1500] += sig @@ -1174,3 +1281,131 @@ def test_save(tmp_path): epochs, fmin=(4, 8, 13, 30), fmax=(8, 13, 30, 45), faverage=True) conn.save(tmp_path / 'foo.nc') + + +def test_multivar_save_load(tmp_path): + """Test saving and loading results of multivariate connectivity.""" + rng = np.random.RandomState(0) + n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000., 20. + data = rng.randn(n_epochs, n_chs, n_times) + sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) + data[:, :, 500:1500] += sig + info = create_info(n_chs, sfreq, 'eeg') + tmin = -1 + epochs = EpochsArray(data, info, tmin=tmin) + tmp_file = os.path.join(tmp_path, 'foo_mvc.nc') + + non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]])) + ragged_indices = (np.array([[0, 1]]), np.array([[2]])) + for indices in [non_ragged_indices, ragged_indices]: + con = spectral_connectivity_epochs( + epochs, method=['mic', 'mim', 'gc', 'gc_tr'], indices=indices, + sfreq=sfreq, fmin=10, fmax=30) + for this_con in con: + this_con.save(tmp_file) + read_con = read_connectivity(tmp_file) + assert_array_almost_equal(this_con.get_data(), + read_con.get_data('raveled')) + if this_con.attrs['patterns'] is not None: + assert_array_almost_equal(np.array(this_con.attrs['patterns']), + np.array(read_con.attrs['patterns'])) + # split `repr` before the file size (`~23 kB` for example) + a = repr(this_con).split('~')[0] 
+ b = repr(read_con).split('~')[0] + assert a == b + + +@pytest.mark.parametrize("method", ['coh', 'plv', 'pli', 'wpli', 'ciplv']) +@pytest.mark.parametrize("indices", [None, + (np.array([0, 1]), np.array([2, 3]))]) +def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): + """Test that indices values and type is maintained after saving. + + If `indices` is None, `indices` in the returned connectivity object should + be None, otherwise, `indices` should be a tuple. The type of `indices` and + its values should be retained after saving and reloading. + """ + rng = np.random.RandomState(0) + n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 + data = rng.randn(n_epochs, n_chs, n_times) + info = create_info(n_chs, sfreq, "eeg") + tmin = -1 + epochs = EpochsArray(data, info, tmin=tmin) + freqs = np.arange(10, 31) + tmp_file = os.path.join(tmp_path, "foo_mvc.nc") + + # test the pair of method and indices defined to check the output indices + con_epochs = spectral_connectivity_epochs( + epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 + ) + con_time = spectral_connectivity_time( + epochs, freqs, method=method, indices=indices, sfreq=sfreq + ) + + for con in [con_epochs, con_time]: + con.save(tmp_file) + read_con = read_connectivity(tmp_file) + + if indices is not None: + # check indices of same type (tuples) + assert isinstance(con.indices, tuple) and isinstance( + read_con.indices, tuple + ) + # check indices have same values + assert np.all(np.array(con.indices) == np.array(read_con.indices)) + else: + assert con.indices is None and read_con.indices is None + + +@pytest.mark.parametrize("method", ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize("indices", [None, + (np.array([[0, 1]]), np.array([[2, 3]]))]) +def test_multivar_spectral_connectivity_indices_roundtrip_io( + tmp_path, method, indices +): + """Test that indices values and type is maintained after saving. 
+
+    If `indices` is None, `indices` in the returned connectivity object should
+    be None, otherwise, `indices` should be a tuple. The type of `indices` and
+    its values should be retained after saving and reloading.
+    """
+    rng = np.random.RandomState(0)
+    n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0
+    data = rng.randn(n_epochs, n_chs, n_times)
+    info = create_info(n_chs, sfreq, "eeg")
+    tmin = -1
+    epochs = EpochsArray(data, info, tmin=tmin)
+    freqs = np.arange(10, 31)
+    tmp_file = os.path.join(tmp_path, "foo_mvc.nc")
+
+    # test the pair of method and indices defined to check the output indices
+    if indices is None and method in ['gc', 'gc_tr']:
+        # indices must be specified for GC
+        pytest.skip()
+
+    con_epochs = spectral_connectivity_epochs(
+        epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30,
+        gc_n_lags=10
+    )
+    con_time = spectral_connectivity_time(
+        epochs, freqs, method=method, indices=indices, sfreq=sfreq,
+        gc_n_lags=10
+    )
+
+    for con in [con_epochs, con_time]:
+        con.save(tmp_file)
+        read_con = read_connectivity(tmp_file)
+
+        if indices is not None:
+            # check indices of same type (tuples)
+            assert isinstance(con.indices, tuple) and isinstance(
+                read_con.indices, tuple
+            )
+            # check indices are masked
+            assert all([np.ma.isMA(inds) for inds in con.indices] and
+                       [np.ma.isMA(inds) for inds in read_con.indices])
+            # check indices have same values
+            assert np.all([con_inds == read_inds for con_inds, read_inds in
+                           zip(con.indices, read_con.indices)])
+        else:
+            assert con.indices is None and read_con.indices is None
diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py
index 7c1aabe6..07dc4e57 100644
--- a/mne_connectivity/spectral/time.py
+++ b/mne_connectivity/spectral/time.py
@@ -16,7 +16,7 @@ from .epochs import (_MICEst, _MIMEst, _GCEst, _GCTREst, _compute_freq_mask,
                      _check_rank_input)
 from .smooth import _create_kernel, _smooth_spectra
-from ..utils import check_indices, fill_doc
+from ..utils 
import check_indices, _check_multivariate_indices, fill_doc
 
 
 _multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr']
@@ -70,10 +70,13 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False,
         :class:`EpochSpectralConnectivity`.
     indices : tuple of array_like | None
         Two arrays with indices of connections for which to compute
-        connectivity. If a multivariate method is called, the indices are for a
-        single connection between all seeds and all targets. If None, all
-        connections are computed, unless a Granger causality method is called,
-        in which case an error is raised.
+        connectivity. If a bivariate method is called, each array for the seeds
+        and targets should contain the channel indices for each bivariate
+        connection. If a multivariate method is called, each array for the
+        seeds and targets should consist of nested arrays containing
+        the channel indices for each multivariate connection. If None,
+        connections between all channels are computed, unless a Granger
+        causality method is called, in which case an error is raised.
     sfreq : float
         The sampling frequency. Required if data is not
         :class:`Epochs `.
@@ -144,11 +147,11 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False,
         :class:`EpochSpectralConnectivity`,
         :class:`SpectralConnectivity` or a list of instances corresponding to
         connectivity measures if several connectivity measures are specified.
-        The shape of each connectivity dataset is
-        (n_epochs, n_signals, n_signals, n_freqs) when ``indices`` is `None`,
-        (n_epochs, n_nodes, n_nodes, n_freqs) when ``indices`` is specified
-        and ``n_nodes = len(indices[0])``, or (n_epochs, 1, 1, n_freqs) when a
-        multi-variate method is called regardless of "indices".
+        The shape of each connectivity dataset is (n_epochs, n_cons, n_freqs).
+        When "indices" is None and a bivariate method is called,
+        "n_cons = n_signals ** 2", or if a multivariate method is called
+        "n_cons = 1". 
When "indices" is specified, "n_cons = len(indices[0])"
+        for bivariate and multivariate methods.
 
     See Also
     --------
@@ -202,13 +205,19 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False,
     scores are in the same order as defined indices.
 
     For multivariate methods, this is handled differently. If "indices" is
-    None, connectivity between all signals will attempt to be computed (this is
-    not possible if a Granger causality method is called). If "indices" is
-    specified, the seeds and targets are treated as a single connection. For
-    example, to compute the connectivity between signals 0, 1, 2 and 3, 4, 5,
-    one would use the same approach as above, however the signals would all be
-    considered for a single connection and the connectivity scores would have
-    the shape (1, n_freqs).
+    None, connectivity between all signals will be computed and a single
+    connectivity spectrum will be returned (this is not possible if a Granger
+    causality method is called). If "indices" is specified, seed and target
+    indices for each connection should be specified as nested array-likes. For
+    example, to compute the connectivity between signals (0, 1) -> (2, 3) and
+    (0, 1) -> (4, 5), indices should be specified as::
+
+        indices = (np.array([[0, 1], [0, 1]]),  # seeds
+                   np.array([[2, 3], [4, 5]]))  # targets
+
+    More information on working with multivariate indices and handling
+    connections where the number of seeds and targets are not equal can be
+    found in the :doc:`../auto_examples/handling_ragged_arrays` example. 
**Supported Connectivity Measures** @@ -398,36 +407,54 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int), - np.arange(n_signals, dtype=int)) + # indices expected to be a masked array, even if not ragged + indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) + indices_use = np.ma.masked_array(indices_use, + mask=False, fill_value=-1) else: logger.info('only using indices for lower-triangular matrix') indices_use = np.tril_indices(n_signals, k=-1) else: if multivariate_con: - if ( - len(np.unique(indices[0])) != len(indices[0]) or - len(np.unique(indices[1])) != len(indices[1]) - ): - raise ValueError( - 'seed and target indices cannot contain repeated ' - 'channels for multivariate connectivity') + # pad ragged indices and mask the invalid entries + indices_use = _check_multivariate_indices(indices, n_signals) if any(this_method in _gc_methods for this_method in method): - if set(indices[0]).intersection(indices[1]): - raise ValueError( - 'seed and target indices must not intersect when ' - 'computing Granger causality') - indices_use = check_indices(indices) - source_idx = indices_use[0] - target_idx = indices_use[1] - n_pairs = len(source_idx) if not multivariate_con else 1 + for seed, target in zip(indices_use[0], indices_use[1]): + intersection = np.intersect1d(seed.compressed(), + target.compressed()) + if intersection.size > 0: + raise ValueError( + 'seed and target indices must not intersect when ' + 'computing Granger causality') + # make sure padded indices are stored in the connectivity object + # create a copy so that `indices_use` can be modified + indices = (indices_use[0].copy(), indices_use[1].copy()) + else: + indices_use = check_indices(indices) + 
n_cons = len(indices_use[0]) # unique signals for which we actually need to compute the CSD of - signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + if multivariate_con: + signals_use = np.unique(indices_use.compressed()) + remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(signals_use)} + remapped_inds = indices_use.copy() + # multivariate functions expect seed/target remapping + for idx in signals_use: + remapped_inds[indices_use == idx] = remapping[idx] + source_idx = remapped_inds[0] + target_idx = remapped_inds[1] + max_n_channels = len(indices_use[0][0]) + else: + # no indices remapping required for bivariate functions + signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + source_idx = indices_use[0].copy() + target_idx = indices_use[1].copy() + max_n_channels = len(indices_use[0]) # check rank input and compute data ranks if necessary if multivariate_con: - rank = _check_rank_input(rank, data, sfreq, indices_use) + rank = _check_rank_input(rank, data, indices_use) else: rank = None gc_n_lags = None @@ -479,9 +506,10 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, conn = dict() conn_patterns = dict() for m in method: - conn[m] = np.zeros((n_epochs, n_pairs, n_freqs)) - conn_patterns[m] = np.full((n_epochs, 2, len(source_idx), n_freqs), - np.nan) + conn[m] = np.zeros((n_epochs, n_cons, n_freqs)) + # patterns shape of [epochs x seeds/targets x cons x channels x freqs] + conn_patterns[m] = np.full( + (n_epochs, 2, n_cons, max_n_channels, n_freqs), np.nan) logger.info('Connectivity computation...') # parameters to pass to the connectivity function @@ -505,8 +533,8 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, if np.isnan(conn_patterns[m]).all(): conn_patterns[m] = None else: - # epochs x 2 x n_channels x n_freqs - conn_patterns[m] = conn_patterns[m].transpose((1, 0, 2, 3)) + # transpose to [seeds/targets x epochs x cons x channels x freqs] + conn_patterns[m] = 
conn_patterns[m].transpose((1, 0, 2, 3, 4)) if indices is None and not multivariate_con: conn_flat = conn @@ -520,11 +548,6 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, conn_flat[m].shape[2:]) conn[m] = this_conn - if multivariate_con: - # UNTIL RAGGED ARRAYS SUPPORTED - indices = tuple( - [[np.array(indices_use[0])], [np.array(indices_use[1])]]) - # create the connectivity containers out = [] for m in method: @@ -569,9 +592,9 @@ def _spectral_connectivity(data, method, kernel, foi_idx, Smoothing kernel. foi_idx : array_like, shape (n_foi, 2) Upper and lower bound indices of frequency bands. - source_idx : array_like, shape (n_pairs,) + source_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``target_idx``. - target_idx : array_like, shape (n_pairs,) + target_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``source_idx``. signals_use : list of int The unique signals on which connectivity is to be computed. @@ -608,8 +631,8 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ------- scores : dict Dictionary containing the connectivity estimates corresponding to the - metrics in ``method``. Each element is an array of shape (n_pairs, - n_freqs) or (n_pairs, n_fbands) if ``faverage`` is `True`. + metrics in ``method``. Each element is an array of shape (n_cons, + n_freqs) or (n_cons, n_fbands) if ``faverage`` is `True`. patterns : dict Dictionary containing the connectivity patterns (for reconstructing the @@ -619,7 +642,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx, or (2, n_channels, 1) if ``faverage`` is `True`, where 2 corresponds to the seed and target signals (respectively). 
""" - n_pairs = len(source_idx) + n_cons = len(source_idx) data = np.expand_dims(data, axis=0) if mode == 'cwt_morlet': out = tfr_array_morlet( @@ -665,12 +688,12 @@ def _spectral_connectivity(data, method, kernel, foi_idx, scores = {} patterns = {} conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx, - signals_use, gc_n_lags, rank, n_jobs, verbose, - n_pairs, faverage, weights, multivariate_con) + signals_use, gc_n_lags, rank, n_jobs, verbose, n_cons, + faverage, weights, multivariate_con) for i, m in enumerate(method): if multivariate_con: scores[m] = conn[0][i] - patterns[m] = conn[1][i][:, 0] if conn[1][i] is not None else None + patterns[m] = conn[1][i] if conn[1][i] is not None else None else: scores[m] = [out[i] for out in conn] patterns[m] = None @@ -699,16 +722,16 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, Smoothing kernel. foi_idx : array_like, shape (n_foi, 2) Upper and lower bound indices of frequency bands. - source_idx : array_like, shape (n_pairs,) + source_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``target_idx``. - target_idx : array_like, shape (n_pairs,) + target_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``source_idx``. signals_use : list of int The unique signals on which connectivity is to be computed. gc_n_lags : int Number of lags to use for the vector autoregressive model when computing Granger causality. - rank : tuple of array + rank : tuple of array of int Ranks to project the seed and target data to. n_jobs : int Number of parallel jobs. 
@@ -825,18 +848,23 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, return out -def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, - foi_idx, faverage, weights, gc_n_lags, rank, n_jobs): +def _multivariate_con(w, seeds, targets, signals_use, method, kernel, foi_idx, + faverage, weights, gc_n_lags, rank, n_jobs): """Compute spectral connectivity metrics between multiple signals. Parameters ---------- w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) Time-frequency data. - x : int - Channel index. - y : int - Channel index. + seeds : array, shape of (n_cons, n_channels) + Seed channel indices. ``n_channels`` is the largest number of channels + across all connections, with missing entries padded with ``-1``. + targets : array, shape of (n_cons, n_channels) + Target channel indices. ``n_channels`` is the largest number of + channels across all connections, with missing entries padded with + ``-1``. + signals_use : list of int + The unique signals on which connectivity is to be computed. method : str Connectivity method. kernel : array_like, shape (n_sm_fres, n_sm_times) @@ -847,6 +875,13 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, Average over frequency bands. weights : array_like, shape (n_tapers, n_freqs, n_times) | None Multitaper weights. + gc_n_lags : int + Number of lags to use for the vector autoregressive model when + computing Granger causality. + rank : tuple of array, shape of (2, n_cons) + Ranks to project the seed and target data to. + n_jobs : int + Number of jobs to run in parallel. Returns ------- @@ -859,8 +894,10 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, List of connectivity patterns between seed and target signals for each connectivity method. 
Each element is an array of length 2 corresponding to the seed and target patterns, respectively, each with shape - (n_channels, n_freqs,) or (n_channels, n_fbands) depending on - ``faverage``. + (n_channels, n_freqs) or (n_channels, n_fbands) + depending on ``faverage``. ``n_channels`` is the largest number of + channels across all connections, with missing entries padded with + ``np.nan``. """ csd = [] for x in signals_use: @@ -880,8 +917,7 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, 'gc_tr': _GCTREst} conn = [] for m in method: - # N_CONS = 1 UNTIL RAGGED ARRAYS SUPPORTED - call_params = {'n_signals': len(signals_use), 'n_cons': 1, + call_params = {'n_signals': len(signals_use), 'n_cons': len(seeds), 'n_freqs': csd.shape[1], 'n_times': 0, 'n_jobs': n_jobs} if m in _gc_methods: @@ -895,7 +931,7 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, scores = [] patterns = [] for con_est in conn: - con_est.compute_con(np.array([source_idx, target_idx]), rank) + con_est.compute_con((seeds, targets), rank) scores.append(con_est.con_scores[..., np.newaxis]) patterns.append(con_est.patterns) if patterns[-1] is not None: diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index c012481c..d80f52a6 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -3,11 +3,15 @@ from numpy.testing import assert_array_equal from mne_connectivity import Connectivity -from mne_connectivity.utils import degree, seed_target_indices +from mne_connectivity.utils import (degree, check_indices, + _check_multivariate_indices, + seed_target_indices, + seed_target_multivariate_indices) -def test_indices(): - """Test connectivity indexing methods.""" +def test_seed_target_indices(): + """Test indices generation functions.""" + # bivariate indices n_seeds_test = [1, 3, 4] n_targets_test = [2, 3, 200] rng = np.random.RandomState(42) @@ -25,6 +29,125 @@ def 
test_indices(): for target in targets: assert np.sum(indices[1] == target) == n_seeds + # multivariate indices + # non-ragged indices + seeds = [[0, 1]] + targets = [[2, 3], [3, 4]] + indices = seed_target_multivariate_indices(seeds, targets) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3], [3, 4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) + # ragged indices + seeds = [[0, 1]] + targets = [[2, 3, 4], [4]] + indices = seed_target_multivariate_indices(seeds, targets) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3, 4], [4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) + # test error catching + # non-array-like seeds/targets + with pytest.raises(TypeError, + match='`seeds` and `targets` must be array-like'): + seed_target_multivariate_indices(0, 1) + # non-nested seeds/targets + with pytest.raises(TypeError, + match='`seeds` and `targets` must contain nested'): + seed_target_multivariate_indices([0], [1]) + # repeated seeds/targets + with pytest.raises(ValueError, + match='`seeds` and `targets` cannot contain repeated'): + seed_target_multivariate_indices([[0, 1, 1]], [[2, 2, 3]]) + + +def test_check_indices(): + """Test check_indices function.""" + # bivariate indices + # test error catching + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_tuple_indices = [[0], [1]] + check_indices(non_tuple_indices) + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_len2_indices = ([0], [1], [2]) + check_indices(non_len2_indices) + with pytest.raises(ValueError, match='Index arrays indices'): + non_equal_len_indices = ([0], [1, 2]) + check_indices(non_equal_len_indices) + with pytest.raises(TypeError, + match='Channel indices 
must be integers, not array'): + nested_indices = ([[0]], [[1]]) + check_indices(nested_indices) + + +def test_check_multivariate_indices(): + """Test _check_multivariate_indices function.""" + n_signals = 5 + mask_value = -1 + # non-ragged indices + seeds = [[0, 1], [0, 1]] + targets = [[2, 3], [3, 4]] + indices = _check_multivariate_indices((seeds, targets), n_signals) + assert np.ma.isMA(indices) + assert indices.fill_value == mask_value + assert np.all(indices == np.array([[[0, 1], [0, 1]], + [[2, 3], [3, 4]]])) + # non-ragged indices with negative values + seeds = [[0, 1], [0, 1]] + targets = [[2, 3], [3, -1]] + indices = _check_multivariate_indices((seeds, targets), n_signals) + assert np.ma.isMA(indices) + assert indices.fill_value == mask_value + assert np.all(indices == np.array([[[0, 1], [0, 1]], + [[2, 3], [3, 4]]])) + # ragged indices + seeds = [[0, 1], [0, 1]] + targets = [[2, 3, 4], [4]] + indices = _check_multivariate_indices((seeds, targets), n_signals) + assert np.ma.isMA(indices) + assert indices.fill_value == mask_value + assert np.all(indices == np.array([[[0, 1, -1], [0, 1, -1]], + [[2, 3, 4], [4, -1, -1]]])) + # ragged indices with negative values + seeds = [[0, 1], [0, 1]] + targets = [[2, 3, 4], [-1]] + indices = _check_multivariate_indices((seeds, targets), n_signals) + assert np.ma.isMA(indices) + assert indices.fill_value == mask_value + assert np.all(indices == np.array([[[0, 1, -1], [0, 1, -1]], + [[2, 3, 4], [4, -1, -1]]])) + + # test error catching + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_tuple_indices = [np.array([0, 1]), np.array([2, 3])] + _check_multivariate_indices(non_tuple_indices, n_signals) + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_len2_indices = (np.array([0]), np.array([1]), np.array([2])) + _check_multivariate_indices(non_len2_indices, n_signals) + with pytest.raises(ValueError, match='index arrays indices'): + non_equal_len_indices 
= (np.array([[0]]), np.array([[1], [2]])) + _check_multivariate_indices(non_equal_len_indices, n_signals) + with pytest.raises(TypeError, + match='multivariate indices must contain array-likes'): + non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + _check_multivariate_indices(non_nested_indices, n_signals) + with pytest.raises(ValueError, + match='multivariate indices cannot contain repeated'): + repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) + _check_multivariate_indices(repeated_indices, n_signals) + with pytest.raises(ValueError, + match='a negative channel index is not present in the'): + missing_chan_indices = (np.array([[0, 1]]), np.array([[2, -5]])) + _check_multivariate_indices(missing_chan_indices, n_signals) + def test_degree(): """Test degree function.""" diff --git a/mne_connectivity/utils/__init__.py b/mne_connectivity/utils/__init__.py index e82f054b..171711b8 100644 --- a/mne_connectivity/utils/__init__.py +++ b/mne_connectivity/utils/__init__.py @@ -1,3 +1,4 @@ from .docs import fill_doc -from .utils import (check_indices, degree, seed_target_indices, +from .utils import (check_indices, _check_multivariate_indices, degree, + seed_target_indices, seed_target_multivariate_indices, parallel_loop, _prepare_xarray_mne_data_structures) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 5ae94acb..1f278400 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -1,4 +1,5 @@ # Authors: Martin Luessi +# Thomas S. Binns # # License: BSD (3-clause) import numpy as np @@ -49,17 +50,24 @@ def par(x): def check_indices(indices): - """Check indices parameter. + """Check indices parameter for bivariate connectivity. Parameters ---------- - indices : tuple of array - Tuple of length 2 containing index pairs. + indices : tuple of array of int, shape (2, n_cons) + Tuple containing index pairs. 
Returns ------- - indices : tuple of array + indices : tuple of array of int, shape (2, n_cons) The indices. + + Notes + ----- + Indices for bivariate connectivity should be a tuple of length 2, + containing the channel indices for the seed and target channel pairs, + respectively. Seed and target indices should be equal-length array-likes of + integers representing the indices of the individual channels in the data. """ if not isinstance(indices, tuple) or len(indices) != 2: raise ValueError('indices must be a tuple of length 2') @@ -68,23 +76,149 @@ def check_indices(indices): raise ValueError('Index arrays indices[0] and indices[1] must ' 'have the same length') + if any(isinstance(inds, (np.ndarray, list, tuple)) for inds in + [*indices[0], *indices[1]]): + raise TypeError('Channel indices must be integers, not array-likes') + return indices +def _check_multivariate_indices(indices, n_chans): + """Check indices parameter for multivariate connectivity and mask it. + + Parameters + ---------- + indices : tuple of array of array of int, shape (2, n_cons, variable) + Tuple containing index sets. + + n_chans : int + The number of channels in the data. Used when converting negative + indices to positive indices. + + Returns + ------- + indices : array of array of int, shape of (2, n_cons, max_n_chans) + The padded indices as a masked array. + + Notes + ----- + Indices for multivariate connectivity should be a tuple of length 2 + containing the channel indices for the seed and target channel sets, + respectively. Seed and target indices should be equal-length array-likes + representing the indices of the channel sets in the data for each + connection. The indices for each connection should be an array-like of + integers representing the individual channels in the data. The length of + indices for each connection do not need to be equal. All indices within a + connection must be unique. 
+ + If the seed and target indices are given as lists or tuples, they will be + converted to numpy arrays. Because the number of channels can differ across + connections or between the seeds and targets for a given connection (i.e. + ragged/jagged indices), the returned array will be padded out to a 'full' + array with an invalid index (``-1``) according to the maximum number of + channels in the seed or target of any one connection. These invalid + entries are then masked and returned as numpy masked arrays. E.g. the + ragged indices of shape ``(2, n_cons, variable)``:: + + indices = ([[0, 1], [0, 1 ]], # seeds + [[2, 3], [4, 5, 6]]) # targets + + would be padded to full arrays:: + + indices = ([[0, 1, -1], [0, 1, -1]], # seeds + [[2, 3, -1], [4, 5, 6]]) # targets + + to have shape ``(2, n_cons, max_n_chans)``, where ``max_n_chans = 3``. The + invalid entries are then masked:: + + indices = ([[0, 1, --], [0, 1, --]], # seeds + [[2, 3, --], [4, 5, 6]]) # targets + + In case "indices" contains negative values to index channels, these will be + converted to the corresponding positive-valued index before any masking is + applied. + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. 
+ """ + if not isinstance(indices, tuple) or len(indices) != 2: + raise ValueError('indices must be a tuple of length 2') + + if len(indices[0]) != len(indices[1]): + raise ValueError('index arrays indices[0] and indices[1] must ' + 'have the same length') + + n_cons = len(indices[0]) + invalid = -1 + + max_n_chans = 0 + for group_idx, group in enumerate(indices): + for con_idx, con in enumerate(group): + if not isinstance(con, (np.ndarray, list, tuple)): + raise TypeError( + 'multivariate indices must contain array-likes of channel ' + 'indices for each seed and target') + con = np.array(con) + if len(con) != len(np.unique(con)): + raise ValueError( + 'multivariate indices cannot contain repeated channels ' + 'within a seed or target') + max_n_chans = max(max_n_chans, len(con)) + # convert negative to positive indices + for chan_idx, chan in enumerate(con): + if chan < 0: + if chan * -1 >= n_chans: + raise ValueError( + 'a negative channel index is not present in the ' + 'data' + ) + indices[group_idx][con_idx][chan_idx] = chan % n_chans + + # pad indices to avoid ragged arrays + padded_indices = np.full((2, n_cons, max_n_chans), invalid, dtype=np.int32) + for con_i, (seed, target) in enumerate(zip(indices[0], indices[1])): + padded_indices[0, con_i, :len(seed)] = seed + padded_indices[1, con_i, :len(target)] = target + + # mask invalid indices + masked_indices = np.ma.masked_values(padded_indices, invalid) + + return masked_indices + + def seed_target_indices(seeds, targets): - """Generate indices parameter for seed based connectivity analysis. + """Generate indices parameter for bivariate seed-based connectivity. Parameters ---------- - seeds : array of int | int + seeds : array of int | int, shape (n_unique_seeds) Seed indices. - targets : array of int | int + targets : array of int | int, shape (n_unique_targets) Indices of signals for which to compute connectivity. 
Returns ------- - indices : tuple of array + indices : tuple of array of int, shape (2, n_cons) The indices parameter used for connectivity computation. + + Notes + ----- + ``seeds`` and ``targets`` should be array-likes or integers representing + the indices of the channel pairs in the data for each connection. ``seeds`` + and ``targets`` will be expanded such that connectivity will be computed + between each seed and each target. E.g. the seeds and targets:: + + seeds = [0, 1] + targets = [2, 3, 4] + + would be returned as:: + + indices = (np.array([0, 0, 0, 1, 1, 1]), # seeds + np.array([2, 3, 4, 2, 3, 4])) # targets + + where the indices have been expanded to have shape ``(2, n_cons)``, where + ``n_cons = n_unique_seeds * n_unique_targets``. """ # make them arrays seeds = np.asarray((seeds,)).ravel() @@ -99,6 +233,77 @@ def seed_target_indices(seeds, targets): return indices +def seed_target_multivariate_indices(seeds, targets): + """Generate indices parameter for multivariate seed-based connectivity. + + Parameters + ---------- + seeds : array of array of int, shape (n_unique_seeds, variable) + Seed indices. + + targets : array of array of int, shape (n_unique_targets, variable) + Target indices. + + Returns + ------- + indices : tuple of array of array of int, shape (2, n_cons, variable) + The indices as a numpy object array. + + Notes + ----- + ``seeds`` and ``targets`` should be array-likes representing the indices of + the channel sets in the data for each connection. The indices for each + connection should be an array-like of integers representing the individual + channels in the data. The length of indices for each connection do not need + to be equal. Furthermore, all indices within a connection must be unique. + + Because the number of channels per connection can vary, the indices are + stored as numpy arrays with ``dtype=object``. E.g. 
``seeds`` and + ``targets``:: + + seeds = [[0]] + targets = [[1, 2], [3, 4, 5]] + + would be returned as:: + + indices = (np.array([[0 ], [0 ]], dtype=object), # seeds + np.array([[1, 2], [3, 4, 5]], dtype=object)) # targets + + Even if the number of channels does not vary, the indices will still be + stored as object arrays for compatibility. + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. + """ + array_like = (np.ndarray, list, tuple) + + if ( + not isinstance(seeds, array_like) or + not isinstance(targets, array_like) + ): + raise TypeError('`seeds` and `targets` must be array-like') + + for inds in [*seeds, *targets]: + if not isinstance(inds, array_like): + raise TypeError( + '`seeds` and `targets` must contain nested array-likes') + if len(inds) != len(np.unique(inds)): + raise ValueError( + '`seeds` and `targets` cannot contain repeated channels') + + indices = [[], []] + for seed in seeds: + for target in targets: + indices[0].append(np.array(seed)) + indices[1].append(np.array(target)) + + indices = (np.array(indices[0], dtype=object), + np.array(indices[1], dtype=object)) + + return indices + + def degree(connectivity, threshold_prop=0.2): """Compute the undirected degree of a connectivity matrix.