Skip to content

Commit

Permalink
Merge pull request #167 from drammock/warning
Browse files Browse the repository at this point in the history
type-check data input to spectral_connectivity_time
  • Loading branch information
tsbinns authored Jan 23, 2024
2 parents 9d9f0a9 + c5f600e commit 043d888
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
4 changes: 4 additions & 0 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,10 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode):
indices = (np.array([[0, 1]]), np.array([[2, 3]]))
freqs = np.arange(10, 25 + 1)

# test type-checking of data
with pytest.raises(TypeError, match="must be an instance of Epochs or a NumPy arr"):
spectral_connectivity_time(data="foo", freqs=freqs)

# check bad indices without nested array caught
with pytest.raises(
TypeError, match="multivariate indices must contain array-likes"
Expand Down
3 changes: 2 additions & 1 deletion mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.time_frequency import dpss_windows, tfr_array_morlet, tfr_array_multitaper
from mne.utils import logger, verbose
from mne.utils import _validate_type, logger, verbose

from ..base import EpochSpectralConnectivity, SpectralConnectivity
from ..utils import _check_multivariate_indices, check_indices, fill_doc
Expand Down Expand Up @@ -339,6 +339,7 @@ def spectral_connectivity_time(
events = None
event_id = None
# extract data from Epochs object
_validate_type(data, (np.ndarray, BaseEpochs), "`data`", "Epochs or a NumPy array")
if isinstance(data, BaseEpochs):
names = data.ch_names
sfreq = data.info["sfreq"]
Expand Down

0 comments on commit 043d888

Please sign in to comment.