From d6ba398de507204c681a4a1c3de255d05020ac44 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 2 Nov 2023 16:13:43 +0100 Subject: [PATCH] switched to masked indices for multivariate conn --- examples/handling_ragged_arrays.py | 39 +++---- mne_connectivity/io.py | 6 +- .../spectral/tests/test_spectral.py | 15 +-- mne_connectivity/spectral/time.py | 15 ++- mne_connectivity/tests/test_utils.py | 16 ++- mne_connectivity/utils/utils.py | 105 +++++++++--------- 6 files changed, 94 insertions(+), 102 deletions(-) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index 6ab01791..8076b9e2 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -62,8 +62,8 @@ # ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'), # np.array([[2, 3, 4], [4 ]], dtype='object')) # -# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be -# specified.** +# **N.B. Note that since NumPy v1.19.0, dtype='object' must be specified when +# forming ragged arrays.** # # Just as for bivariate connectivity, the length of ``indices[0]`` and # ``indices[1]`` is equal (i.e. the number of connections), however information @@ -71,19 +71,16 @@ # array. Importantly, these indices are ragged, as the first connection will be # computed between 2 seed and 3 target channels, and the second connection # between 4 seed and 1 target channel. The connectivity functions will -# recognise the indices as being ragged, and pad them accordingly to make them -# easier to work with and compatible with the h5netcdf saving engine. The known -# value used to pad the arrays is ``-1``, an invalid channel index. The above -# indices would be padded to:: +# recognise the indices as being ragged, and pad them to a 'full' array by +# adding placeholder values which are masked accordingly. This makes the +# indices easier to work with, and also compatible with the engine used to save +# connectivity objects. For example, the above indices would become:: # -# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]), -# np.array([[2, 3, 4, -1], [4, -1, -1, -1]])) +# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]), +# np.array([[2, 3, 4, --], [4, --, --, --]])) # -# These indices are what is stored in the connectivity object, and is also the -# format of indices returned from the helper functions -# :func:`~mne_connectivity.check_multivariate_indices` and -# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also -# possible to pass the padded indices to the connectivity functions directly. +# where ``--`` are masked entries. These indices are what is stored in the +# returned connectivity objects. # # For the connectivity results themselves, the methods available in # MNE-Connectivity combine information across the different channels into a @@ -118,11 +115,11 @@ max_n_chans = max( len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])) -# show that the padded indices entries are all -1 -assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels -assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels -assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels -assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels +# show that the padded indices entries are masked +assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels +assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels +assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels +assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels # patterns have shape [seeds/targets x cons x max channels x freqs (x times)] assert patterns.shape == (2, n_cons, max_n_chans, n_freqs) @@ -137,11 +134,11 @@ seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] -# extract patterns for second connection using the padded indices (pad = -1) +# extract patterns for second connection using the padded, masked indices seed_patterns_con2 = ( - patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)]) + patterns[0, 1, :padded_indices[0][1].count()]) target_patterns_con2 = ( - patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)]) + patterns[1, 1, :padded_indices[1][1].count()]) # show that shapes of patterns are correct assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) diff --git a/mne_connectivity/io.py b/mne_connectivity/io.py index 63aa3501..e8d9b916 100644 --- a/mne_connectivity/io.py +++ b/mne_connectivity/io.py @@ -53,9 +53,11 @@ def _xarray_to_conn(array, cls_func): event_id = dict(zip(event_id_keys, event_id_vals)) array.attrs['event_id'] = event_id - # convert indices numpy arrays to a tuple of arrays + # convert indices numpy arrays to a tuple of masked arrays + # (only multivariate connectivity indices saved as arrays) if isinstance(array.attrs['indices'], np.ndarray): - array.attrs['indices'] = tuple(array.attrs['indices']) + array.attrs['indices'] = tuple( + np.ma.masked_values(array.attrs['indices'], -1)) # create the connectivity class conn = cls_func( diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 273fa2d5..a514382b 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1315,12 +1315,9 @@ def test_multivar_save_load(tmp_path): assert a == b -@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", - "mic", "mim"]) +@pytest.mark.parametrize("method", ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize("indices", [None, - (np.array([0, 1]), np.array([2, 3])), - (np.array([[0, 1]]), np.array([[2, 3]])) - ]) + (np.array([0, 1]), np.array([2, 3]))]) def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): """Test that indices values and type is maintained after saving. @@ -1337,14 +1334,6 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") - # mutlivariate and bivariate methods require the right indices shape - if method in ["mic", "mim"]: - if indices is not None and indices[0].ndim == 1: - pytest.skip() - else: - if indices is not None and indices[0].ndim == 2: - pytest.skip() - # test the pair of method and indices defined to check the output indices con_epochs = spectral_connectivity_epochs( epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 340779ef..4ac3f161 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -422,20 +422,19 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, indices_use]) np.ma.set_fill_value(indices_use, -1) # else 99999 after concat. if any(this_method in _gc_methods for this_method in method): - for seed, target in zip(indices[0], indices[1]): - intersection = np.intersect1d(seed, target) - if np.any(intersection != -1): # ignore padded entries + for seed, target in zip(indices_use[0], indices_use[1]): + intersection = np.intersect1d(seed.compressed(), + target.compressed()) + if intersection.size > 0: raise ValueError( 'seed and target indices must not intersect when ' 'computing Granger causality') # make sure padded indices are stored in the connectivity object - indices = tuple(np.array(indices_use)) # create a copy + # create a copy + indices = (indices_use[0].copy(), indices_use[1].copy()) else: indices_use = check_indices(indices) - # create copies of indices_use for independent manipulation - source_idx = np.array(indices_use[0]) - target_idx = np.array(indices_use[1]) - n_cons = len(source_idx) + n_cons = len(indices_use[0]) # unique signals for which we actually need to compute the CSD of if multivariate_con: diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index 31f27d0a..caead679 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -34,14 +34,22 @@ def test_seed_target_indices(): seeds = [[0, 1]] targets = [[2, 3], [3, 4]] indices = seed_target_multivariate_indices(seeds, targets) - assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), - np.array([[2, 3], [3, 4]]))) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3], [3, 4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) # ragged indices seeds = [[0, 1]] targets = [[2, 3, 4], [4]] indices = seed_target_multivariate_indices(seeds, targets) - assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), - np.array([[2, 3, 4], [4, -1, -1]]))) + match_indices = (np.array([[0, 1], [0, 1]], dtype=object), + np.array([[2, 3, 4], [4]], dtype=object)) + for type_i in range(2): + for con_i in range(len(indices[0])): + assert np.all(indices[type_i][con_i] == + match_indices[type_i][con_i]) # test error catching # non-array-like seeds/targets with pytest.raises(TypeError, diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 6696a130..2cab0007 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -98,7 +98,7 @@ def _check_multivariate_indices(indices, n_chans): Returns ------- indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans) - The indices padded with the invalid channel index ``-1``. + The indices as a masked array. Notes ----- @@ -112,26 +112,35 @@ def _check_multivariate_indices(indices, n_chans): connection must be unique. If the seed and target indices are given as lists or tuples, they will be - converted to numpy arrays. In case the number of channels differs across + converted to numpy arrays. Because the number of channels can differ across connections or between the seeds and targets for a given connection (i.e. - ragged indices), the returned array will be padded with the invalid channel - index ``-1`` according to the maximum number of channels in the seed or - target of any one connection. E.g. the ragged indices of shape ``(2, - n_cons, variable)``:: + ragged/jagged indices), the returned array will be padded out to a 'full' + array with an invalid index (``-1``) according to the maximum number of + channels in the seed or target of any one connection. These invalid + entries are then masked and returned as numpy masked arrays. E.g. the + ragged indices of shape ``(2, n_cons, variable)``:: indices = ([[0, 1], [0, 1 ]], # seeds [[2, 3], [4, 5, 6]]) # targets - would be returned as:: + would be padded to full arrays:: + + indices = ([[0, 1, -1], [0, 1, -1]], # seeds + [[2, 3, -1], [4, 5, 6]]) # targets + + to have shape ``(2, n_cons, max_n_chans)``, where ``max_n_chans = 3``. The + invalid entries are then masked:: - indices = (np.array([[0, 1, -1], [0, 1, -1]]), # seeds - np.array([[2, 3, -1], [4, 5, -1]])) # targets + indices = ([[0, 1, --], [0, 1, --]], # seeds + [[2, 3, --], [4, 5, 6]]) # targets - where the indices have been padded with ``-1`` to have shape ``(2, n_cons, - max_n_chans)``, where ``max_n_chans = 3``. More information on working with - multivariate indices and handling connections where the number of seeds and - targets are not equal can be found in the - :doc:`../auto_examples/handling_ragged_arrays` example. + In case "indices" contains negative values to index channels, these will be + converted to the corresponding positive-valued index before any masking is + applied. + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. """ if not isinstance(indices, tuple) or len(indices) != 2: raise ValueError('indices must be a tuple of length 2') @@ -141,6 +150,7 @@ def _check_multivariate_indices(indices, n_chans): 'have the same length') n_cons = len(indices[0]) + invalid = -1 max_n_chans = 0 for group_idx, group in enumerate(indices): @@ -166,15 +176,19 @@ def _check_multivariate_indices(indices, n_chans): indices[group_idx][con_idx][chan_idx] = chan % n_chans # pad indices to avoid ragged arrays - padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32), - np.full((n_cons, max_n_chans), -1, dtype=np.int32)) + padded_indices = (np.full((n_cons, max_n_chans), invalid, dtype=np.int32), + np.full((n_cons, max_n_chans), invalid, dtype=np.int32)) con_i = 0 for seed, target in zip(indices[0], indices[1]): padded_indices[0][con_i, :len(seed)] = seed padded_indices[1][con_i, :len(target)] = target con_i += 1 - return padded_indices + # mask invalid indices + masked_indices = (np.ma.masked_values(padded_indices[0], invalid), + np.ma.masked_values(padded_indices[1], invalid)) + + return masked_indices def seed_target_indices(seeds, targets): @@ -236,8 +250,8 @@ def seed_target_multivariate_indices(seeds, targets): Returns ------- - indices : tuple of array of array of int, shape (2, n_cons, max_n_chans) - The indices padded with the invalid channel index ``-1``. + indices : tuple of array of array of int, shape (2, n_cons, variable) + The indices as a numpy object array. Notes ----- @@ -247,12 +261,8 @@ def seed_target_multivariate_indices(seeds, targets): channels in the data. The length of indices for each connection do not need to be equal. Furthermore, all indices within a connection must be unique. - ``seeds`` and ``targets`` will be expanded such that connectivity will be - computed between each set of seeds and targets. In case the number of - channels differs across connections or between the seeds and targets for a - given connection (i.e. ragged indices), the returned array will be padded - with the invalid channel index ``-1`` according to the maximum number of - channels in the seed or target of any one connection. E.g. ``seeds`` and + Because the number of channels per connection can vary, the indices are + stored as numpy arrays with ``dtype=object``. E.g. ``seeds`` and ``targets``:: seeds = [[0]] @@ -260,15 +270,15 @@ def seed_target_multivariate_indices(seeds, targets): would be returned as:: - indices = (np.array([[0, -1, -1], [0, -1, -1]]), # seeds - np.array([[1, 2, -1], [3, 4, 5]])) # targets + indices = (np.array([[0 ], [0 ]], dtype=object), # seeds + np.array([[1, 2], [3, 4, 5]], dtype=object)) # targets + + Even if the number of channels does not vary, the indices will still be + stored as object arrays for compatibility. - where the indices have been padded with ``-1`` to have shape ``(2, n_cons, - max_n_chans)``, where ``n_cons = n_unique_seeds * n_unique_targets`` and - ``max_n_chans = 3``. More information on working with multivariate indices - and handling connections where the number of seeds and targets are not - equal can be found in the :doc:`../auto_examples/handling_ragged_arrays` - example. + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. """ array_like = (np.ndarray, list, tuple) @@ -278,7 +288,6 @@ def seed_target_multivariate_indices(seeds, targets): ): raise TypeError('`seeds` and `targets` must be array-like') - n_chans = [] for inds in [*seeds, *targets]: if not isinstance(inds, array_like): raise TypeError( @@ -286,27 +295,15 @@ def seed_target_multivariate_indices(seeds, targets): if len(inds) != len(np.unique(inds)): raise ValueError( '`seeds` and `targets` cannot contain repeated channels') - n_chans.append(len(inds)) - max_n_chans = max(n_chans) - n_cons = len(seeds) * len(targets) - # pad indices to avoid ragged arrays - padded_seeds = np.full((len(seeds), max_n_chans), -1, dtype=np.int32) - padded_targets = np.full((len(targets), max_n_chans), -1, dtype=np.int32) - for con_i, seed in enumerate(seeds): - padded_seeds[con_i, :len(seed)] = seed - for con_i, target in enumerate(targets): - padded_targets[con_i, :len(target)] = target - - # create final indices - indices = (np.zeros((n_cons, max_n_chans), dtype=np.int32), - np.zeros((n_cons, max_n_chans), dtype=np.int32)) - con_i = 0 - for seed in padded_seeds: - for target in padded_targets: - indices[0][con_i] = seed - indices[1][con_i] = target - con_i += 1 + indices = [[], []] + for seed in seeds: + for target in targets: + indices[0].append(np.array(seed)) + indices[1].append(np.array(target)) + + indices = (np.array(indices[0], dtype=object), + np.array(indices[1], dtype=object)) return indices