diff --git a/doc/api.rst b/doc/api.rst index 300601b4..f919f74b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -86,4 +86,14 @@ Visualization functions :toctree: generated/ plot_sensors_connectivity - plot_connectivity_circle \ No newline at end of file + plot_connectivity_circle + +Dataset functions +================= + +.. currentmodule:: mne_connectivity + +.. autosummary:: + :toctree: generated/ + + make_signals_in_freq_bands \ No newline at end of file diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index fc55e2d0..d271c706 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -17,6 +17,7 @@ SpectroTemporalConnectivity, TemporalConnectivity, ) +from .datasets import make_signals_in_freq_bands from .effective import phase_slope_index from .envelope import envelope_correlation, symmetric_orth from .io import read_connectivity diff --git a/mne_connectivity/datasets/__init__.py b/mne_connectivity/datasets/__init__.py new file mode 100644 index 00000000..d5c8e2eb --- /dev/null +++ b/mne_connectivity/datasets/__init__.py @@ -0,0 +1 @@ +from .frequency import make_signals_in_freq_bands diff --git a/mne_connectivity/datasets/frequency.py b/mne_connectivity/datasets/frequency.py new file mode 100644 index 00000000..8fd29308 --- /dev/null +++ b/mne_connectivity/datasets/frequency.py @@ -0,0 +1,153 @@ +# Authors: Adam Li +# Thomas S. Binns +# +# License: BSD (3-clause) + +import numpy as np +from mne import EpochsArray, create_info +from mne.filter import filter_data + + +def make_signals_in_freq_bands( + n_seeds, + n_targets, + freq_band, + n_epochs=10, + n_times=200, + sfreq=100, + trans_bandwidth=1, + snr=0.7, + connection_delay=5, + tmin=0, + ch_names=None, + ch_types="eeg", + rng_seed=None, +): + """Simulate signals interacting in a given frequency band. + + Parameters + ---------- + n_seeds : int + Number of seed channels to simulate. + n_targets : int + Number of target channels to simulate. + freq_band : tuple of int or float + Frequency band where the connectivity should be simulated, where the first entry + corresponds to the lower frequency, and the second entry to the higher + frequency. + n_epochs : int (default 10) + Number of epochs in the simulated data. + n_times : int (default 200) + Number of timepoints each epoch of the simulated data. + sfreq : int | float (default 100) + Sampling frequency of the simulated data, in Hz. + trans_bandwidth : int | float (default 1) + Transition bandwidth of the filter to apply to isolate activity in + ``freq_band``, in Hz. These are passed to the ``l_bandwidth`` and + ``h_bandwidth`` keyword arguments in :func:`mne.filter.create_filter`. + snr : float (default 0.7) + Signal-to-noise ratio of the simulated data in the range [0, 1]. + connection_delay : int (default 5) + Number of timepoints for the delay of connectivity between the seeds and + targets. If > 0, the target data is a delayed form of the seed data. If < 0, the + seed data is a delayed form of the target data. + tmin : int | float (default 0) + Earliest time of each epoch. + ch_names : list of str | None (default None) + Names of the channels in the simulated data. If `None`, the channels are named + according to their index and the frequency band of interaction. If specified, + must be a list of ``n_seeds + n_targets`` channel names. + ch_types : str | list of str (default "eeg") + Types of the channels in the simulated data. If specified as a list, must be a + list of ``n_seeds + n_targets`` channel names. + rng_seed : int | None (default None) + Seed to use for the random number generator. If `None`, no seed is specified. + + Returns + ------- + epochs : mne.EpochsArray of shape (n_epochs, ``n_seeds + n_targets``, n_times) + The simulated data stored in an `mne.EpochsArray` object. The channels are + arranged according to seeds, then targets. + + Notes + ----- + Signals are simulated as a single source of activity in the given frequency band and + projected into ``n_seeds + n_targets`` noise channels. + """ + n_channels = n_seeds + n_targets + + # check inputs + if n_seeds < 1 or n_targets < 1: + raise ValueError("Number of seeds and targets must each be at least 1.") + + if not isinstance(freq_band, tuple): + raise TypeError("Frequency band must be a tuple.") + if len(freq_band) != 2: + raise ValueError("Frequency band must contain two numbers.") + + if n_times < 1: + raise ValueError("Number of timepoints must be at least 1.") + + if n_epochs < 1: + raise ValueError("Number of epochs must be at least 1.") + + if sfreq <= 0: + raise ValueError("Sampling frequency must be > 0.") + + if snr < 0 or snr > 1: + raise ValueError("Signal-to-noise ratio must be between 0 and 1.") + + if np.abs(connection_delay) >= n_epochs * n_times: + raise ValueError( + "Connection delay must be less than the total number of timepoints." + ) + + # simulate data + rng = np.random.default_rng(rng_seed) + + # simulate signal source at desired frequency band + signal = rng.standard_normal( + size=(1, n_epochs * n_times + np.abs(connection_delay)) + ) + signal = filter_data( + data=signal, + sfreq=sfreq, + l_freq=freq_band[0], + h_freq=freq_band[1], + l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth, + fir_design="firwin2", + ) + + # simulate noise for each channel + noise = rng.standard_normal( + size=(n_channels, n_epochs * n_times + np.abs(connection_delay)) + ) + + # create data by projecting signal into each channel of noise + data = (signal * snr) + (noise * (1 - snr)) + + # shift data by desired delay and remove extra time + if connection_delay != 0: + if connection_delay > 0: + delay_chans = np.arange(n_seeds, n_channels) # delay targets + else: + delay_chans = np.arange(0, n_seeds) # delay seeds + data[delay_chans, np.abs(connection_delay) :] = data[ + delay_chans, : n_epochs * n_times + ] + data = data[:, : n_epochs * n_times] + + # reshape data into epochs + data = data.reshape(n_channels, n_epochs, n_times) + data = data.transpose((1, 0, 2)) # (epochs x channels x times) + + # store data in an MNE EpochsArray object + if ch_names is None: + ch_names = [ + f"{ch_i}_{freq_band[0]}_{freq_band[1]}" for ch_i in range(n_channels) + ] + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) + epochs = EpochsArray(data=data, info=info, tmin=tmin) + + return epochs diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 96e18d6c..8c493cb3 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1,3 +1,4 @@ +import inspect import os import numpy as np @@ -9,6 +10,7 @@ from mne_connectivity import ( SpectralConnectivity, + make_signals_in_freq_bands, read_connectivity, spectral_connectivity_epochs, spectral_connectivity_time, @@ -21,6 +23,7 @@ from mne_connectivity.spectral.epochs_bivariate import _CohEst +# TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests def create_test_dataset( sfreq, n_signals, n_epochs, n_times, tmin, tmax, fstart, fend, trans_bandwidth=2.0 ): @@ -103,34 +106,18 @@ def _stc_gen(data, sfreq, tmin, combo=False): @pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) def test_spectral_connectivity_parallel(method, mode, tmp_path): """Test saving spectral connectivity with parallel functions.""" - # Use a case known to have no spurious correlations (it would bad if - # tests could randomly fail): - rng = np.random.RandomState(0) - trans_bandwidth = 2.0 - - sfreq = 50.0 - n_signals = 3 - n_epochs = 8 - n_times = 256 n_jobs = 2 # test with parallelization - data = rng.randn(n_signals, n_epochs * n_times) - # simulate connectivity from 5Hz..15Hz - fstart, fend = 5.0, 15.0 - data[1, :] = filter_data( - data[0, :], - sfreq, - fstart, - fend, - filter_length="auto", - fir_design="firwin2", - l_trans_bandwidth=trans_bandwidth, - h_trans_bandwidth=trans_bandwidth, + data = make_signals_in_freq_bands( + n_seeds=2, + n_targets=1, + freq_band=(5, 15), + n_epochs=8, + n_times=256, + sfreq=50, + trans_bandwidth=2.0, + rng_seed=0, # case with no spurious correlations (avoid tests randomly failing) ) - # add some noise, so the spectrum is not exactly zero - data[1, :] += 1e-2 * rng.randn(n_times * n_epochs) - data = data.reshape(n_signals, n_epochs, n_times) - data = np.transpose(data, [1, 0, 2]) # define some frequencies for cwt cwt_freqs = np.arange(3, 24.5, 1) @@ -158,7 +145,6 @@ def test_spectral_connectivity_parallel(method, mode, tmp_path): method=method, mode=mode, indices=None, - sfreq=sfreq, mt_adaptive=adaptive, mt_low_bias=True, mt_bandwidth=mt_bandwidth, @@ -210,6 +196,7 @@ def test_spectral_connectivity(method, mode): # 5Hz..15Hz fstart, fend = 5.0, 15.0 + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data, times_data = create_test_dataset( sfreq, n_signals=n_signals, @@ -501,6 +488,7 @@ def test_spectral_connectivity_epochs_multivariate(method): # 15-25 Hz connectivity fstart, fend = 15.0, 25.0 rng = np.random.RandomState(0) + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data = rng.randn(n_signals, n_epochs * n_times + delay) # simulate connectivity from fstart to fend data[n_seeds:, :] = filter_data( @@ -732,12 +720,17 @@ def test_multivariate_spectral_connectivity_epochs_regression(): @pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): """Test error catching for multivar. freq.-domain connectivity methods.""" - sfreq = 50.0 - n_signals = 4 # Do not change! - n_epochs = 8 - n_times = 256 - rng = np.random.RandomState(0) - data = rng.randn(n_epochs, n_signals, n_times) + sfreq = 50 # Hz + data = make_signals_in_freq_bands( + n_seeds=2, # do not change! + n_targets=2, # do not change! + freq_band=(10, 20), # arbitrary for this test + n_epochs=8, + n_times=256, + sfreq=sfreq, + rng_seed=0, + ) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) cwt_freqs = np.arange(10, 25 + 1) @@ -747,12 +740,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): ): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) spectral_connectivity_epochs( - data, - method=method, - mode=mode, - indices=non_nested_indices, - sfreq=sfreq, - gc_n_lags=10, + data, method=method, mode=mode, indices=non_nested_indices, gc_n_lags=10 ) # check bad indices with repeated channels caught @@ -761,12 +749,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): ): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) spectral_connectivity_epochs( - data, - method=method, - mode=mode, - indices=repeated_indices, - sfreq=sfreq, - gc_n_lags=10, + data, method=method, mode=mode, indices=repeated_indices, gc_n_lags=10 ) # check mixed methods caught @@ -776,12 +759,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): elif isinstance(method, list): mixed_methods = [*method, "coh"] spectral_connectivity_epochs( - data, - method=mixed_methods, - mode=mode, - indices=indices, - sfreq=sfreq, - cwt_freqs=cwt_freqs, + data, method=mixed_methods, mode=mode, indices=indices, cwt_freqs=cwt_freqs ) # check bad rank args caught @@ -792,7 +770,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_low_rank, cwt_freqs=cwt_freqs, ) @@ -803,7 +780,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_high_rank, cwt_freqs=cwt_freqs, ) @@ -814,7 +790,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_few_rank, cwt_freqs=cwt_freqs, ) @@ -825,13 +800,16 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, rank=too_much_rank, cwt_freqs=cwt_freqs, ) # check rank-deficient data caught - bad_data = data.copy() + # XXX: remove logic once support for mne<1.6 is dropped + kwargs = dict() + if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: + kwargs["copy"] = False + bad_data = data.get_data(**kwargs) bad_data[:, 1] = bad_data[:, 0] bad_data[:, 3] = bad_data[:, 2] assert np.all(np.linalg.matrix_rank(bad_data[:, (0, 1), :]) == 1) @@ -872,7 +850,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, fmin=frange[0], fmax=frange[1], gc_n_lags=n_lags, @@ -882,12 +859,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): # check no indices caught with pytest.raises(ValueError, match="indices must be specified"): spectral_connectivity_epochs( - data, - method=method, - mode=mode, - indices=None, - sfreq=sfreq, - cwt_freqs=cwt_freqs, + data, method=method, mode=mode, indices=None, cwt_freqs=cwt_freqs ) # check intersecting indices caught @@ -896,12 +868,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): ValueError, match="seed and target indices must not intersect" ): spectral_connectivity_epochs( - data, - method=method, - mode=mode, - indices=bad_indices, - sfreq=sfreq, - cwt_freqs=cwt_freqs, + data, method=method, mode=mode, indices=bad_indices, cwt_freqs=cwt_freqs ) # check bad fmin/fmax caught @@ -911,7 +878,6 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, fmin=(10.0, 15.0), fmax=(15.0, 20.0), cwt_freqs=cwt_freqs, @@ -933,22 +899,20 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): @pytest.mark.parametrize("method", ["mic", "mim", "gc", "gc_tr"]) def test_multivar_spectral_connectivity_parallel(method): """Test multivar. freq.-domain connectivity methods run in parallel.""" - sfreq = 50.0 - n_signals = 4 # Do not change! - n_epochs = 8 - n_times = 256 - rng = np.random.RandomState(0) - data = rng.randn(n_epochs, n_signals, n_times) + data = make_signals_in_freq_bands( + n_seeds=2, # do not change! + n_targets=2, # do not change! + freq_band=(10, 20), # arbitrary for this test + n_epochs=8, + n_times=256, + sfreq=50, + rng_seed=0, + ) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) spectral_connectivity_epochs( - data, - method=method, - mode="multitaper", - indices=indices, - sfreq=sfreq, - gc_n_lags=10, - n_jobs=2, + data, method=method, mode="multitaper", indices=indices, gc_n_lags=10, n_jobs=2 ) spectral_connectivity_time( data, @@ -956,7 +920,6 @@ def test_multivar_spectral_connectivity_parallel(method): method=method, mode="multitaper", indices=indices, - sfreq=sfreq, gc_n_lags=10, n_jobs=2, ) @@ -964,12 +927,16 @@ def test_multivar_spectral_connectivity_parallel(method): def test_multivar_spectral_connectivity_flipped_indices(): """Test multivar. indices structure maintained by connectivity methods.""" - sfreq = 50.0 - n_signals = 4 - n_epochs = 8 - n_times = 256 - rng = np.random.RandomState(0) - data = rng.randn(n_epochs, n_signals, n_times) + data = make_signals_in_freq_bands( + n_seeds=2, # do not change! + n_targets=2, # do not change! + freq_band=(10, 20), # arbitrary for this test + n_epochs=8, + n_times=256, + sfreq=50, + rng_seed=0, + ) + freqs = np.arange(10, 20) # if we're not careful, when finding the channels we need to compute the @@ -982,26 +949,26 @@ def test_multivar_spectral_connectivity_flipped_indices(): method = "gc" con_st = spectral_connectivity_epochs( # seed -> target - data, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10 + data, method=method, indices=indices, gc_n_lags=10 ) con_ts = spectral_connectivity_epochs( # target -> seed - data, method=method, indices=flipped_indices, sfreq=sfreq, gc_n_lags=10 + data, method=method, indices=flipped_indices, gc_n_lags=10 ) con_st_ts = spectral_connectivity_epochs( # seed -> target; target -> seed - data, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10 + data, method=method, indices=concat_indices, gc_n_lags=10 ) assert not np.all(con_st.get_data() == con_ts.get_data()) assert np.all(con_st.get_data()[0] == con_st_ts.get_data()[0]) assert np.all(con_ts.get_data()[0] == con_st_ts.get_data()[1]) con_st = spectral_connectivity_time( # seed -> target - data, freqs, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10 + data, freqs, method=method, indices=indices, gc_n_lags=10 ) con_ts = spectral_connectivity_time( # target -> seed - data, freqs, method=method, indices=flipped_indices, sfreq=sfreq, gc_n_lags=10 + data, freqs, method=method, indices=flipped_indices, gc_n_lags=10 ) con_st_ts = spectral_connectivity_time( # seed -> target; target -> seed - data, freqs, method=method, indices=concat_indices, sfreq=sfreq, gc_n_lags=10 + data, freqs, method=method, indices=concat_indices, gc_n_lags=10 ) assert not np.all(con_st.get_data() == con_ts.get_data()) assert np.all(con_st.get_data()[:, 0] == con_st_ts.get_data()[:, 0]) @@ -1183,6 +1150,7 @@ def test_spectral_connectivity_time_delayed(): # 20-30 Hz connectivity fstart, fend = 20.0, 30.0 rng = np.random.RandomState(0) + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data = rng.randn(n_signals, n_epochs * n_times + delay) # simulate connectivity from fstart to fend data[n_seeds:, :] = filter_data( @@ -1313,6 +1281,7 @@ def test_spectral_connectivity_time_resolved(method, mode): tmax = (n_times - 1) / sfreq # 5Hz..15Hz fstart, fend = 5.0, 15.0 + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data, _ = create_test_dataset( sfreq, n_signals=n_signals, @@ -1372,6 +1341,7 @@ def test_spectral_connectivity_time_padding(method, mode, padding): tmax = (n_times - 1) / sfreq # 5Hz..15Hz fstart, fend = 5.0, 15.0 + # TODO: Replace with `make_signals_in_freq_bands` after tweaking tolerances in tests data, _ = create_test_dataset( sfreq, n_signals=n_signals, @@ -1444,12 +1414,17 @@ def test_spectral_connectivity_time_padding(method, mode, padding): @pytest.mark.parametrize("faverage", [True, False]) def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): """Test result shapes of time-resolved multivar. connectivity methods.""" - sfreq = 50.0 - n_signals = 4 # Do not change! n_epochs = 8 - n_times = 500 - rng = np.random.RandomState(0) - data = rng.randn(n_epochs, n_signals, n_times) + data = make_signals_in_freq_bands( + n_seeds=2, # do not change! + n_targets=2, # do not change! + freq_band=(10, 20), # arbitrary for this test + n_epochs=n_epochs, + n_times=256, + sfreq=50, + rng_seed=0, + ) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) n_cons = len(indices[0]) freqs = np.arange(10, 25 + 1) @@ -1468,7 +1443,6 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): freqs, indices=indices, method=method, - sfreq=sfreq, faverage=faverage, average=average, gc_n_lags=10, @@ -1497,7 +1471,6 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): freqs, indices=indices, method=method, - sfreq=sfreq, faverage=faverage, average=average, gc_n_lags=10, @@ -1520,11 +1493,19 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): @pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) def test_multivar_spectral_connectivity_time_error_catch(method, mode): """Test error catching for time-resolved multivar. connectivity methods.""" - sfreq = 50.0 - n_signals = 4 # Do not change! - n_epochs = 8 - n_times = 256 - data = np.random.rand(n_epochs, n_signals, n_times) + n_seeds = 2 # do not change! + n_targets = 2 # do not change! + n_signals = n_seeds + n_targets + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=(10, 20), # arbitrary for this test + n_epochs=8, + n_times=256, + sfreq=50, + rng_seed=0, + ) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) freqs = np.arange(10, 25 + 1) @@ -1538,12 +1519,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ): non_nested_indices = (np.array([0, 1]), np.array([2, 3])) spectral_connectivity_time( - data, - freqs, - method=method, - mode=mode, - indices=non_nested_indices, - sfreq=sfreq, + data, freqs, method=method, mode=mode, indices=non_nested_indices ) # check bad indices with repeated channels caught @@ -1552,66 +1528,42 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ): repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) spectral_connectivity_time( - data, freqs, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq + data, freqs, method=method, mode=mode, indices=repeated_indices ) # check mixed methods caught with pytest.raises(ValueError, match="bivariate and multivariate connectivity"): mixed_methods = [method, "coh"] spectral_connectivity_time( - data, freqs, method=mixed_methods, mode=mode, indices=indices, sfreq=sfreq + data, freqs, method=mixed_methods, mode=mode, indices=indices ) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_time( - data, - freqs, - method=method, - indices=indices, - sfreq=sfreq, - mode=mode, - rank=too_low_rank, + data, freqs, method=method, indices=indices, mode=mode, rank=too_low_rank ) too_high_rank = (np.array([3]), np.array([3])) with pytest.raises(ValueError, match="ranks for seeds and targets must be"): spectral_connectivity_time( - data, - freqs, - method=method, - indices=indices, - sfreq=sfreq, - mode=mode, - rank=too_high_rank, + data, freqs, method=method, indices=indices, mode=mode, rank=too_high_rank ) too_few_rank = ([], []) with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_time( - data, - freqs, - method=method, - indices=indices, - sfreq=sfreq, - mode=mode, - rank=too_few_rank, + data, freqs, method=method, indices=indices, mode=mode, rank=too_few_rank ) too_much_rank = (np.array([2, 2]), np.array([2, 2])) with pytest.raises(ValueError, match="rank argument must have shape"): spectral_connectivity_time( - data, - freqs, - method=method, - indices=indices, - sfreq=sfreq, - mode=mode, - rank=too_much_rank, + data, freqs, method=method, indices=indices, mode=mode, rank=too_much_rank ) # check all-to-all conn. computed for MIC/MIM when no indices given if method in ["mic", "mim"]: con = spectral_connectivity_time( - data, freqs, method=method, indices=None, sfreq=sfreq, mode=mode + data, freqs, method=method, indices=None, mode=mode ) assert con.indices is None assert con.n_nodes == n_signals @@ -1622,7 +1574,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): # check no indices caught with pytest.raises(ValueError, match="indices must be specified"): spectral_connectivity_time( - data, freqs, method=method, mode=mode, indices=None, sfreq=sfreq + data, freqs, method=method, mode=mode, indices=None ) # check intersecting indices caught @@ -1631,7 +1583,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ValueError, match="seed and target indices must not intersect" ): spectral_connectivity_time( - data, freqs, method=method, mode=mode, indices=bad_indices, sfreq=sfreq + data, freqs, method=method, mode=mode, indices=bad_indices ) # check bad fmin/fmax caught @@ -1642,7 +1594,6 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): method=method, mode=mode, indices=indices, - sfreq=sfreq, fmin=(5.0, 15.0), fmax=(15.0, 30.0), ) @@ -1650,14 +1601,15 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): def test_save(tmp_path): """Test saving results of spectral connectivity.""" - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 10, 3, 2000, 1000.0, 20.0 - data = rng.randn(n_epochs, n_chs, n_times) - sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) - data[:, :, 500:1500] += sig - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) + epochs = make_signals_in_freq_bands( + n_seeds=2, + n_targets=1, + freq_band=(18, 22), # arbitrary for this test + n_epochs=10, + n_times=2000, + sfreq=1000, + rng_seed=0, + ) conn = spectral_connectivity_epochs( epochs, fmin=(4, 8, 13, 30), fmax=(8, 13, 30, 45), faverage=True @@ -1667,14 +1619,16 @@ def test_save(tmp_path): def test_multivar_save_load(tmp_path): """Test saving and loading results of multivariate connectivity.""" - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000.0, 20.0 - data = rng.randn(n_epochs, n_chs, n_times) - sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) - data[:, :, 500:1500] += sig - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) + epochs = make_signals_in_freq_bands( + n_seeds=2, + n_targets=2, + freq_band=(18, 22), # arbitrary for this test + n_epochs=5, + n_times=2000, + sfreq=1000, + rng_seed=0, + ) + tmp_file = os.path.join(tmp_path, "foo_mvc.nc") non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]])) @@ -1684,7 +1638,6 @@ def test_multivar_save_load(tmp_path): epochs, method=["mic", "mim", "gc", "gc_tr"], indices=indices, - sfreq=sfreq, fmin=10, fmax=30, ) @@ -1712,22 +1665,24 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): be None, otherwise, `indices` should be a tuple. The type of `indices` and its values should be retained after saving and reloading. """ - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 - data = rng.randn(n_epochs, n_chs, n_times) - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) + epochs = make_signals_in_freq_bands( + n_seeds=2, + n_targets=2, + freq_band=(18, 22), # arbitrary for this test + n_epochs=10, + n_times=200, + sfreq=100, + rng_seed=0, + ) + freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") # test the pair of method and indices defined to check the output indices con_epochs = spectral_connectivity_epochs( - epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30 - ) - con_time = spectral_connectivity_time( - epochs, freqs, method=method, indices=indices, sfreq=sfreq + epochs, method=method, indices=indices, fmin=10, fmax=30 ) + con_time = spectral_connectivity_time(epochs, freqs, method=method, indices=indices) for con in [con_epochs, con_time]: con.save(tmp_file) @@ -1753,12 +1708,16 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io(tmp_path, method, i be None, otherwise, `indices` should be a tuple. The type of `indices` and its values should be retained after saving and reloading. """ - rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq = 5, 4, 200, 100.0 - data = rng.randn(n_epochs, n_chs, n_times) - info = create_info(n_chs, sfreq, "eeg") - tmin = -1 - epochs = EpochsArray(data, info, tmin=tmin) + epochs = make_signals_in_freq_bands( + n_seeds=2, + n_targets=2, + freq_band=(18, 22), # arbitrary for this test + n_epochs=10, + n_times=200, + sfreq=100, + rng_seed=0, + ) + freqs = np.arange(10, 31) tmp_file = os.path.join(tmp_path, "foo_mvc.nc") @@ -1768,16 +1727,10 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io(tmp_path, method, i pytest.skip() con_epochs = spectral_connectivity_epochs( - epochs, - method=method, - indices=indices, - sfreq=sfreq, - fmin=10, - fmax=30, - gc_n_lags=10, + epochs, method=method, indices=indices, fmin=10, fmax=30, gc_n_lags=10 ) con_time = spectral_connectivity_time( - epochs, freqs, method=method, indices=indices, sfreq=sfreq, gc_n_lags=10 + epochs, freqs, method=method, indices=indices, gc_n_lags=10 ) for con in [con_epochs, con_time]: diff --git a/mne_connectivity/tests/test_datasets.py b/mne_connectivity/tests/test_datasets.py new file mode 100644 index 00000000..4ae7d5ac --- /dev/null +++ b/mne_connectivity/tests/test_datasets.py @@ -0,0 +1,169 @@ +import numpy as np +import pytest + +from mne_connectivity import ( + make_signals_in_freq_bands, + seed_target_indices, + spectral_connectivity_epochs, +) + + +@pytest.mark.parametrize("n_seeds", [1, 3]) +@pytest.mark.parametrize("n_targets", [1, 3]) +@pytest.mark.parametrize("snr", [0.7, 0.4]) +@pytest.mark.parametrize("connection_delay", [0, 3, -3]) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) +def test_make_signals_in_freq_bands(n_seeds, n_targets, snr, connection_delay, mode): + """Test `make_signals_in_freq_bands` simulates connectivity properly.""" + # Case with no spurious correlations (avoids tests randomly failing) + rng_seed = 0 + + # Simulate data + freq_band = (5, 10) # fmin, fmax (Hz) + sfreq = 100 # Hz + trans_bandwidth = 1 # Hz + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=freq_band, + n_epochs=30, + n_times=200, + sfreq=sfreq, + trans_bandwidth=trans_bandwidth, + snr=snr, + connection_delay=connection_delay, + rng_seed=rng_seed, + ) + + # Compute connectivity + methods = ["coh", "imcoh", "dpli"] + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + fmin = 3 + fmax = sfreq // 2 + if mode == "cwt_morlet": + cwt_params = {"cwt_freqs": np.arange(fmin, fmax), "cwt_n_cycles": 3.5} + else: + cwt_params = dict() + con = spectral_connectivity_epochs( + data, + method=methods, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + **cwt_params, + ) + freqs = np.array(con[0].freqs) + + # Define expected connectivity values + thresh_good = dict() + thresh_bad = dict() + # Coh + thresh_good["coh"] = (0.2, 0.9) + thresh_bad["coh"] = (0.0, 0.2) + # ImCoh + if connection_delay == 0: + thresh_good["imcoh"] = (0.0, 0.17) + thresh_bad["imcoh"] = (0.0, 0.17) + else: + thresh_good["imcoh"] = (0.17, 0.8) + thresh_bad["imcoh"] = (0.0, 0.17) + # DPLI + if connection_delay == 0: + thresh_good["dpli"] = (0.3, 0.6) + thresh_bad["dpli"] = (0.3, 0.6) + elif connection_delay > 0: + thresh_good["dpli"] = (0.5, 1) + thresh_bad["dpli"] = (0.3, 0.6) + else: + thresh_good["dpli"] = (0, 0.5) + thresh_bad["dpli"] = (0.3, 0.6) + + # Check connectivity values are acceptable + freqs_good = np.argwhere( + (freqs >= freq_band[0]) & (freqs <= freq_band[1]) + ).flatten() + freqs_bad = np.argwhere( + (freqs < freq_band[0] - trans_bandwidth * 2) + | (freqs > freq_band[1] + trans_bandwidth * 2) + ).flatten() + for method_name, method_con in zip(methods, con): + con_values = method_con.get_data() + if method_name == "imcoh": + con_values = np.abs(con_values) + # freq. band of interest + con_values_good = np.mean(con_values[:, freqs_good]) + assert ( + con_values_good >= thresh_good[method_name][0] + and con_values_good <= thresh_good[method_name][1] + ) + + # other freqs. + con_values_bad = np.mean(con_values[:, freqs_bad]) + assert ( + con_values_bad >= thresh_bad[method_name][0] + and con_values_bad <= thresh_bad[method_name][1] + ) + + +def test_make_signals_error_catch(): + """Test error catching for `make_signals_in_freq_bands`.""" + freq_band = (5, 10) + + # check bad n_seeds/targets caught + with pytest.raises( + ValueError, match="Number of seeds and targets must each be at least 1." + ): + make_signals_in_freq_bands(n_seeds=0, n_targets=1, freq_band=freq_band) + with pytest.raises( + ValueError, match="Number of seeds and targets must each be at least 1." + ): + make_signals_in_freq_bands(n_seeds=1, n_targets=0, freq_band=freq_band) + + # check bad freq_band caught + with pytest.raises(TypeError, match="Frequency band must be a tuple."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=1) + with pytest.raises(ValueError, match="Frequency band must contain two numbers."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=(1, 2, 3)) + + # check bad n_times + with pytest.raises(ValueError, match="Number of timepoints must be at least 1."): + make_signals_in_freq_bands( + n_seeds=1, n_targets=1, freq_band=freq_band, n_times=0 + ) + + # check bad n_epochs + with pytest.raises(ValueError, match="Number of epochs must be at least 1."): + make_signals_in_freq_bands( + n_seeds=1, n_targets=1, freq_band=freq_band, n_epochs=0 + ) + + # check bad sfreq + with pytest.raises(ValueError, match="Sampling frequency must be > 0."): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, sfreq=0) + + # check bad snr + with pytest.raises( + ValueError, match="Signal-to-noise ratio must be between 0 and 1." + ): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=-1) + with pytest.raises( + ValueError, match="Signal-to-noise ratio must be between 0 and 1." + ): + make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=2) + + # check bad connection_delay + with pytest.raises( + ValueError, + match="Connection delay must be less than the total number of timepoints.", + ): + make_signals_in_freq_bands( + n_seeds=1, + n_targets=1, + freq_band=freq_band, + n_epochs=1, + n_times=1, + connection_delay=1, + )