Skip to content

Commit

Permalink
[MRG] [BUG] [ENH] [WIP] Bug fixes and enhancements for time-resolved …
Browse files Browse the repository at this point in the history
…spectral connectivity estimation (#104)
  • Loading branch information
ruuskas authored Nov 21, 2022
1 parent 53b6b16 commit 4476efa
Show file tree
Hide file tree
Showing 11 changed files with 568 additions and 320 deletions.
1 change: 1 addition & 0 deletions doc/authors.inc
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
.. _Szonja Weigl: https://github.com/weiglszonja
.. _Kenji Marshall: https://github.com/kenjimarshall
.. _Sezan Mert: https://github.com/SezanMert
.. _Santeri Ruuskanen: https://github.com/ruuskas
3 changes: 2 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@
'use_edit_page_button': False,
'navigation_with_keys': False,
'show_toc_level': 1,
'navbar_end': ['version-switcher', 'navbar-icon-links'],
'navbar_end': ['theme-switcher', 'version-switcher', 'navbar-icon-links'],
'secondary_sidebar_items': ['page-toc'],
}
# Custom sidebar templates, maps document names to template names.
html_sidebars = {
Expand Down
18 changes: 13 additions & 5 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,35 @@ Here we list a changelog of MNE-connectivity.
Version 0.5 (Unreleased)
------------------------

...
This version has major changes in :func:`mne_connectivity.spectral_connectivity_time`. Several bugs are fixed, and the
function now computes static connectivity over time, as opposed to static connectivity over trials computed by :func:`mne_connectivity.spectral_connectivity_epochs`.

Enhancements
~~~~~~~~~~~~

-
- Add the ``PLI`` and ``wPLI`` methods in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
- Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
- Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
- Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`).

Bug
~~~

-
- When using the ``multitaper`` mode in :func:`mne_connectivity.spectral_connectivity_time`, average CSD over tapers instead of the complex signal by `Santeri Ruuskanen`_ (:gh:`104`).
- Average over time when computing connectivity measures in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
- Fix support for multiple connectivity methods in calls to :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
- Fix bug with the ``indices`` parameter in :func:`mne_connectivity.spectral_connectivity_time`, the behavior is now as expected by `Santeri Ruuskanen`_ (:gh:`104`).
- Fix bug with parallel computation in :func:`mne_connectivity.spectral_connectivity_time`, add instructions for memory mapping in doc by `Santeri Ruuskanen`_ (:gh:`104`).

API
~~~

-
- Streamline the API of :func:`mne_connectivity.spectral_connectivity_time` with :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`).

Authors
~~~~~~~

*
* `Santeri Ruuskanen`_

:doc:`Find out what was new in previous releases <whats_new_previous_releases>`

Expand Down
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: mne
name: mne-connectivity
channels:
- conda-forge
dependencies:
Expand All @@ -20,5 +20,5 @@ dependencies:
- pyvista>=0.32
- pyvistaqt>=0.4
- pyqt!=5.15.3
- mne
- mne>=1.0
- h5netcdf
2 changes: 1 addition & 1 deletion mne_connectivity/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,6 @@ def _check_skip_backend(name):
if not has_imageio_ffmpeg():
pytest.skip("Test skipped, requires imageio-ffmpeg")
if name == 'pyvistaqt' and not _check_qt_version():
pytest.skip("Test skipped, requires PyQt5.")
pytest.skip("Test skipped, requires Python Qt bindings.")
if name == 'pyvistaqt' and not has_pyvistaqt():
pytest.skip("Test skipped, requires pyvistaqt")
13 changes: 8 additions & 5 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
See Also
--------
mne_connectivity.spectral_connectivity_time
mne_connectivity.SpectralConnectivity
mne_connectivity.SpectroTemporalConnectivity
Expand All @@ -873,7 +874,9 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
connectivity structure. Within each Epoch, it is assumed that the spectral
measure is stationary. The spectral measures implemented in this function
are computed across Epochs. **Thus, spectral measures computed with only
one Epoch will result in errorful values.**
one Epoch will result in errorful values and spectral measures computed
with few Epochs will be unreliable.** Please see
``spectral_connectivity_time`` for time-resolved connectivity estimation.
The spectral densities can be estimated using a multitaper method with
digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier
Expand All @@ -891,11 +894,11 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
indices = (np.array([0, 0, 0]), # row indices
np.array([2, 3, 4])) # col indices
con_flat = spectral_connectivity(data, method='coh',
indices=indices, ...)
con = spectral_connectivity_epochs(data, method='coh',
indices=indices, ...)
In this case con_flat.shape = (3, n_freqs). The connectivity scores are
in the same order as defined indices.
In this case con.get_data().shape = (3, n_freqs). The connectivity scores
are in the same order as defined indices.
**Supported Connectivity Measures**
Expand Down
210 changes: 104 additions & 106 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,8 @@
from numpy.testing import (assert_allclose, assert_array_almost_equal,
assert_array_less)
import pytest
import warnings

import mne
from mne import (EpochsArray, SourceEstimate, create_info,
make_fixed_length_epochs)
from mne import (EpochsArray, SourceEstimate, create_info)
from mne.filter import filter_data
from mne.utils import _resource_path
from mne_bids import BIDSPath, read_raw_bids

from mne_connectivity import (
SpectralConnectivity, spectral_connectivity_epochs,
Expand Down Expand Up @@ -478,15 +472,109 @@ def test_epochs_tmin_tmax(kind):
assert len(w) == 1 # just one even though there were multiple epochs


@pytest.mark.parametrize('method', ['coh', 'plv'])
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize(
'mode', ['cwt_morlet', 'multitaper'])
@pytest.mark.parametrize('data_option', ['sync', 'random'])
def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
"""Test time-resolved spectral connectivity with simulated phase-locked
data."""
rng = np.random.default_rng(0)
n_epochs = 5
n_channels = 3
n_times = 1000
sfreq = 250
data = np.zeros((n_epochs, n_channels, n_times))
if data_option == 'random':
# Data is random, there should be no consistent phase differences.
data = rng.random((n_epochs, n_channels, n_times))
if data_option == 'sync':
# Data consists of phase-locked 10Hz sine waves with constant phase
# difference within each epoch.
wave_freq = 10
epoch_length = n_times / sfreq
for i in range(n_epochs):
for c in range(n_channels):
phase = rng.random() * 10
x = np.linspace(-wave_freq * epoch_length * np.pi + phase,
wave_freq * epoch_length * np.pi + phase,
n_times)
data[i, c] = np.squeeze(np.sin(x))
# the frequency band should contain the frequency at which there is a
# hypothesized "connection"
freq_band_low_limit = (8.)
freq_band_high_limit = (13.)
cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1)
con = spectral_connectivity_time(data, method=method, mode=mode,
sfreq=sfreq, fmin=freq_band_low_limit,
fmax=freq_band_high_limit,
cwt_freqs=cwt_freqs, n_jobs=1,
faverage=True, average=True, sm_times=0)
assert con.shape == (n_channels ** 2, len(con.freqs))
con_matrix = con.get_data('dense')[..., 0]
if data_option == 'sync':
# signals are perfectly phase-locked, connectivity matrix should be
# a lower triangular matrix of ones
assert np.allclose(con_matrix,
np.tril(np.ones(con_matrix.shape),
k=-1),
atol=0.01)
if data_option == 'random':
# signals are random, all connectivity values should be small
# 0.5 is picked rather arbitrarily such that the obsolete wrong
# implementation fails
assert np.all(con_matrix) <= 0.5


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize(
'cwt_freqs', [[8., 10.], [8, 10], 10., 10])
def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs):
"""Test time-resolved spectral connectivity with int and float values for
cwt_freqs."""
rng = np.random.default_rng(0)
n_epochs = 5
n_channels = 3
n_times = 1000
sfreq = 250
data = np.zeros((n_epochs, n_channels, n_times))

# Data consists of phase-locked 10Hz sine waves with constant phase
# difference within each epoch.
wave_freq = 10
epoch_length = n_times / sfreq
for i in range(n_epochs):
for c in range(n_channels):
phase = rng.random() * 10
x = np.linspace(-wave_freq * epoch_length * np.pi + phase,
wave_freq * epoch_length * np.pi + phase,
n_times)
data[i, c] = np.squeeze(np.sin(x))
# the frequency band should contain the frequency at which there is a
# hypothesized "connection"
con = spectral_connectivity_time(data, method=method, mode='cwt_morlet',
sfreq=sfreq, fmin=np.min(cwt_freqs),
fmax=np.max(cwt_freqs),
cwt_freqs=cwt_freqs, n_jobs=1,
faverage=True, average=True, sm_times=0)
assert con.shape == (n_channels ** 2, len(con.freqs))
con_matrix = con.get_data('dense')[..., 0]

# signals are perfectly phase-locked, connectivity matrix should be
# a lower triangular matrix of ones
assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1),
atol=0.01)


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize(
'mode', ['cwt_morlet', 'multitaper'])
def test_spectral_connectivity_time_resolved(method, mode):
"""Test time-resolved spectral connectivity."""
sfreq = 50.
n_signals = 3
n_epochs = 2
n_times = 256
n_times = 1000
trans_bandwidth = 2.
tmin = 0.
tmax = (n_times - 1) / sfreq
Expand All @@ -502,22 +590,21 @@ def test_spectral_connectivity_time_resolved(method, mode):

# define some frequencies for cwt
freqs = np.arange(3, 20.5, 1)
n_freqs = len(freqs)

# run connectivity estimation
con = spectral_connectivity_time(
data, freqs=freqs, method=method, mode=mode)
assert con.shape == (n_epochs, n_signals * 2, n_freqs, n_times)
data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode,
n_cycles=5)
assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs))
assert con.get_data(output='dense').shape == \
(n_epochs, n_signals, n_signals, n_freqs, n_times)

# average over time
conn_data = con.get_data(output='dense').mean(axis=-1)
conn_data = conn_data.mean(axis=-1)
(n_epochs, n_signals, n_signals, len(con.freqs))

# test the simulated signal
triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T

# average over frequencies
conn_data = con.get_data(output='dense').mean(axis=-1)

# the indices at which there is a correlation should be greater
# then the rest of the components
for epoch_idx in range(n_epochs):
Expand All @@ -526,95 +613,6 @@ def test_spectral_connectivity_time_resolved(method, mode):
for idx, jdx in triu_inds)


@pytest.mark.parametrize('method', ['coh', 'plv'])
@pytest.mark.parametrize(
'mode', ['morlet', 'multitaper'])
def test_time_resolved_spectral_conn_regression(method, mode):
"""Regression test against original implementation in Frites.
To see how the test dataset was generated, see
``benchmarks/single_epoch_conn.py``.
"""
test_file_path_str = str(_resource_path(
'mne_connectivity.tests',
f'data/test_frite_dataset_{mode}_{method}.npy'))
test_conn = np.load(test_file_path_str)

# paths to mne datasets - sample ECoG
bids_root = mne.datasets.epilepsy_ecog.data_path()

# first define the BIDS path and load in the dataset
bids_path = BIDSPath(root=bids_root, subject='pt1', session='presurgery',
task='ictal', datatype='ieeg', extension='.vhdr')
with warnings.catch_warnings():
warnings.simplefilter("ignore")
raw = read_raw_bids(bids_path=bids_path, verbose=False)
line_freq = raw.info['line_freq']

# Pick only the ECoG channels, removing the ECG channels
raw.pick_types(ecog=True)

# drop bad channels
raw.drop_channels(raw.info['bads'])

# only pick the first three channels to lower RAM usage
raw = raw.pick_channels(raw.ch_names[:3])

# Load the data
raw.load_data()

# Then we remove line frequency interference
raw.notch_filter(line_freq)

# crop data and then Epoch
raw_copy = raw.copy()
raw = raw.crop(tmin=0, tmax=4, include_tmax=False)
epochs = make_fixed_length_epochs(raw=raw, duration=2., overlap=1.)

######################################################################
# Perform basic test to match simulation data using time-resolved spec
######################################################################
# compare data to original run using Frites
freqs = [30, 90]

# mode was renamed in mne-connectivity
if mode == 'morlet':
mode = 'cwt_morlet'
conn = spectral_connectivity_time(
epochs, freqs=freqs, n_jobs=1, method=method, mode=mode)

# frites only stores the upper triangular parts of the raveled array
row_triu_inds, col_triu_inds = np.triu_indices(len(raw.ch_names), k=1)
conn_data = conn.get_data(output='dense')[
:, row_triu_inds, col_triu_inds, ...]
assert_array_almost_equal(conn_data, test_conn)

######################################################################
# Give varying set of frequency bands and frequencies to perform cWT
######################################################################
raw = raw_copy.crop(tmin=0, tmax=10, include_tmax=False)
ch_names = epochs.ch_names
epochs = make_fixed_length_epochs(raw=raw, duration=5, overlap=0.)

# sampling rate of my data
sfreq = raw.info['sfreq']

# frequency bands of interest
fois = np.array([[4, 8], [8, 12], [12, 16], [16, 32]])

# frequencies of Continuous Morlet Wavelet Transform
freqs = np.arange(4., 32., 1)

# compute coherence
cohs = spectral_connectivity_time(
epochs, names=None, method=method, indices=None,
sfreq=sfreq, foi=fois, sm_times=0.5, sm_freqs=1, sm_kernel='hanning',
mode=mode, mt_bandwidth=None, freqs=freqs, n_cycles=5)
assert cohs.get_data(output='dense').shape == (
len(epochs), len(ch_names), len(ch_names), len(fois), len(epochs.times)
)


def test_save(tmp_path):
"""Test saving results of spectral connectivity."""
rng = np.random.RandomState(0)
Expand Down
Loading

0 comments on commit 4476efa

Please sign in to comment.