From 241b6fffc919d326ca31996311edbacd115f142f Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 28 Nov 2023 15:00:55 +0100 Subject: [PATCH] [BUG] Fixed issue w/ different rank-indices length (#158) --- .../spectral/epochs_multivariate.py | 6 ++++++ .../spectral/tests/test_spectral.py | 20 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 91de421b..2a5d0816 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -48,6 +48,12 @@ def _check_rank_input(rank, data, indices): rank = tuple((np.array(rank[0]), np.array(rank[1]))) else: + if ( + len(rank) != 2 or len(rank[0]) != len(indices[0]) or + len(rank[1]) != len(indices[1]) + ): + raise ValueError('rank argument must have shape (2, n_cons), ' + 'according to n_cons in the indices') for seed_idcs, target_idcs, seed_rank, target_rank in zip( indices[0], indices[1], rank[0], rank[1]): if not (0 < seed_rank <= len(seed_idcs) and diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index eea3bc04..2cb21403 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -657,6 +657,16 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=too_high_rank, cwt_freqs=cwt_freqs) + too_few_rank = ([], []) + with pytest.raises(ValueError, match='rank argument must have shape'): + spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices, + sfreq=sfreq, rank=too_few_rank, cwt_freqs=cwt_freqs) + too_much_rank = (np.array([2, 2]), np.array([2, 2])) + with pytest.raises(ValueError, match='rank argument must have shape'): + spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices, + sfreq=sfreq, rank=too_much_rank, cwt_freqs=cwt_freqs) # check rank-deficient data caught bad_data = data.copy() @@ -1236,6 +1246,16 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): spectral_connectivity_time( data, freqs, method=method, indices=indices, sfreq=sfreq, mode=mode, rank=too_high_rank) + too_few_rank = ([], []) + with pytest.raises(ValueError, match='rank argument must have shape'): + spectral_connectivity_time( + data, freqs, method=method, indices=indices, sfreq=sfreq, + mode=mode, rank=too_few_rank) + too_much_rank = (np.array([2, 2]), np.array([2, 2])) + with pytest.raises(ValueError, match='rank argument must have shape'): + spectral_connectivity_time( + data, freqs, method=method, indices=indices, sfreq=sfreq, + mode=mode, rank=too_much_rank) # check all-to-all conn. computed for MIC/MIM when no indices given if method in ['mic', 'mim']: