Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
richardkoehler committed Jan 12, 2023
2 parents 9769f42 + ed40aef commit 388283c
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 209 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ jobs:
- add_ssh_keys:
fingerprints:
- "d6:88:6b:a0:80:bf:14:8e:80:2e:ed:11:52:26:37:68"
- "2d:3e:74:bc:d4:55:4f:75:dd:13:cf:59:ac:45:dc:de"

- run:
# push built docs into the `dev` directory on the `gh-pages` branch
Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Enhancements
- Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
- Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
- Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`).
- Add the ``ciPLV`` method in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`).
- Add the option to use the edges of each epoch as padding in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`).

Bug
~~~
Expand Down
86 changes: 72 additions & 14 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def test_epochs_tmin_tmax(kind):
assert len(w) == 1 # just one even though there were multiple epochs


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv'])
@pytest.mark.parametrize(
'mode', ['cwt_morlet', 'multitaper'])
@pytest.mark.parametrize('data_option', ['sync', 'random'])
Expand Down Expand Up @@ -504,11 +504,11 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
# hypothesized "connection"
freq_band_low_limit = (8.)
freq_band_high_limit = (13.)
cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1)
con = spectral_connectivity_time(data, method=method, mode=mode,
freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1)
con = spectral_connectivity_time(data, freqs, method=method, mode=mode,
sfreq=sfreq, fmin=freq_band_low_limit,
fmax=freq_band_high_limit,
cwt_freqs=cwt_freqs, n_jobs=1,
n_jobs=1,
faverage=True, average=True, sm_times=0)
assert con.shape == (n_channels ** 2, len(con.freqs))
con_matrix = con.get_data('dense')[..., 0]
Expand All @@ -526,12 +526,13 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
assert np.all(con_matrix) <= 0.5


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv'])
@pytest.mark.parametrize(
'cwt_freqs', [[8., 10.], [8, 10], 10., 10])
def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs):
'freqs', [[8., 10.], [8, 10], 10., 10])
@pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper'])
def test_spectral_connectivity_time_freqs(method, freqs, mode):
"""Test time-resolved spectral connectivity with int and float values for
cwt_freqs."""
freqs."""
rng = np.random.default_rng(0)
n_epochs = 5
n_channels = 3
Expand All @@ -552,10 +553,10 @@ def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs):
data[i, c] = np.squeeze(np.sin(x))
# the frequency band should contain the frequency at which there is a
# hypothesized "connection"
con = spectral_connectivity_time(data, method=method, mode='cwt_morlet',
sfreq=sfreq, fmin=np.min(cwt_freqs),
fmax=np.max(cwt_freqs),
cwt_freqs=cwt_freqs, n_jobs=1,
con = spectral_connectivity_time(data, freqs, method=method,
mode=mode, sfreq=sfreq,
fmin=np.min(freqs),
fmax=np.max(freqs), n_jobs=1,
faverage=True, average=True, sm_times=0)
assert con.shape == (n_channels ** 2, len(con.freqs))
con_matrix = con.get_data('dense')[..., 0]
Expand Down Expand Up @@ -588,12 +589,12 @@ def test_spectral_connectivity_time_resolved(method, mode):
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
data = EpochsArray(data, info)

# define some frequencies for cwt
# define some frequencies for tfr
freqs = np.arange(3, 20.5, 1)

# run connectivity estimation
con = spectral_connectivity_time(
data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode,
data, freqs, sfreq=sfreq, method=method, mode=mode,
n_cycles=5)
assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs))
assert con.get_data(output='dense').shape == \
Expand All @@ -613,6 +614,63 @@ def test_spectral_connectivity_time_resolved(method, mode):
for idx, jdx in triu_inds)


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize(
'mode', ['cwt_morlet', 'multitaper'])
@pytest.mark.parametrize('padding', [0, 1, 5])
def test_spectral_connectivity_time_padding(method, mode, padding):
"""Test time-resolved spectral connectivity with padding."""
sfreq = 50.
n_signals = 3
n_epochs = 2
n_times = 300
trans_bandwidth = 2.
tmin = 0.
tmax = (n_times - 1) / sfreq
# 5Hz..15Hz
fstart, fend = 5.0, 15.0
data, _ = create_test_dataset(
sfreq, n_signals=n_signals, n_epochs=n_epochs, n_times=n_times,
tmin=tmin, tmax=tmax,
fstart=fstart, fend=fend, trans_bandwidth=trans_bandwidth)
ch_names = np.arange(n_signals).astype(str).tolist()
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
data = EpochsArray(data, info)

# define some frequencies for tfr
freqs = np.arange(3, 20.5, 1)

# run connectivity estimation
if padding == 5:
with pytest.raises(ValueError, match='Padding cannot be larger than '
'half of data length'):
con = spectral_connectivity_time(
data, freqs, sfreq=sfreq, method=method, mode=mode,
n_cycles=5, padding=padding)
return
else:
con = spectral_connectivity_time(
data, freqs, sfreq=sfreq, method=method, mode=mode,
n_cycles=5, padding=padding)

assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs))
assert con.get_data(output='dense').shape == \
(n_epochs, n_signals, n_signals, len(con.freqs))

# test the simulated signal
triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T

# average over frequencies
conn_data = con.get_data(output='dense').mean(axis=-1)

# the indices at which there is a correlation should be greater
# then the rest of the components
for epoch_idx in range(n_epochs):
high_conn_val = conn_data[epoch_idx, 0, 1]
assert all(high_conn_val >= conn_data[epoch_idx, idx, jdx]
for idx, jdx in triu_inds)


def test_save(tmp_path):
"""Test saving results of spectral connectivity."""
rng = np.random.RandomState(0)
Expand Down
Loading

0 comments on commit 388283c

Please sign in to comment.