Skip to content

Commit

Permalink
[BUG] Fixed issue w/ different rank-indices length (#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns authored Nov 28, 2023
1 parent c523fe0 commit 241b6ff
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def _check_rank_input(rank, data, indices):
rank = tuple((np.array(rank[0]), np.array(rank[1])))

else:
if (
len(rank) != 2 or len(rank[0]) != len(indices[0]) or
len(rank[1]) != len(indices[1])
):
raise ValueError('rank argument must have shape (2, n_cons), '
'according to n_cons in the indices')
for seed_idcs, target_idcs, seed_rank, target_rank in zip(
indices[0], indices[1], rank[0], rank[1]):
if not (0 < seed_rank <= len(seed_idcs) and
Expand Down
20 changes: 20 additions & 0 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,16 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode):
spectral_connectivity_epochs(
data, method=method, mode=mode, indices=indices,
sfreq=sfreq, rank=too_high_rank, cwt_freqs=cwt_freqs)
too_few_rank = ([], [])
with pytest.raises(ValueError, match='rank argument must have shape'):
spectral_connectivity_epochs(
data, method=method, mode=mode, indices=indices,
sfreq=sfreq, rank=too_few_rank, cwt_freqs=cwt_freqs)
too_much_rank = (np.array([2, 2]), np.array([2, 2]))
with pytest.raises(ValueError, match='rank argument must have shape'):
spectral_connectivity_epochs(
data, method=method, mode=mode, indices=indices,
sfreq=sfreq, rank=too_much_rank, cwt_freqs=cwt_freqs)

# check rank-deficient data caught
bad_data = data.copy()
Expand Down Expand Up @@ -1236,6 +1246,16 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode):
spectral_connectivity_time(
data, freqs, method=method, indices=indices, sfreq=sfreq,
mode=mode, rank=too_high_rank)
too_few_rank = ([], [])
with pytest.raises(ValueError, match='rank argument must have shape'):
spectral_connectivity_time(
data, freqs, method=method, indices=indices, sfreq=sfreq,
mode=mode, rank=too_few_rank)
too_much_rank = (np.array([2, 2]), np.array([2, 2]))
with pytest.raises(ValueError, match='rank argument must have shape'):
spectral_connectivity_time(
data, freqs, method=method, indices=indices, sfreq=sfreq,
mode=mode, rank=too_much_rank)

# check all-to-all conn. computed for MIC/MIM when no indices given
if method in ['mic', 'mim']:
Expand Down

0 comments on commit 241b6ff

Please sign in to comment.