Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add connectivity simulation function #173

Merged
merged 12 commits into from
Mar 18, 2024
12 changes: 11 additions & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,14 @@ Visualization functions
:toctree: generated/

plot_sensors_connectivity
plot_connectivity_circle
plot_connectivity_circle

Dataset functions
=================

.. currentmodule:: mne_connectivity

.. autosummary::
:toctree: generated/

make_signals_in_freq_bands
1 change: 1 addition & 0 deletions mne_connectivity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
SpectroTemporalConnectivity,
TemporalConnectivity,
)
from .datasets import make_signals_in_freq_bands
from .effective import phase_slope_index
from .envelope import envelope_correlation, symmetric_orth
from .io import read_connectivity
Expand Down
1 change: 1 addition & 0 deletions mne_connectivity/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .frequency import make_signals_in_freq_bands
146 changes: 146 additions & 0 deletions mne_connectivity/datasets/frequency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Authors: Adam Li <[email protected]>
# Thomas S. Binns <[email protected]>
#
# License: BSD (3-clause)

import numpy as np
from mne import EpochsArray, create_info
from mne.filter import filter_data


def make_signals_in_freq_bands(
n_seeds,
n_targets,
freq_band,
n_epochs=10,
n_times=200,
sfreq=100,
trans_bandwidth=1,
snr=0.7,
connection_delay=5,
tmin=0,
ch_names=None,
ch_types="eeg",
rng_seed=None,
):
"""Simulate signals interacting in a given frequency band.

Parameters
----------
n_seeds : int
Number of seed channels to simulate.
n_targets : int
Number of target channels to simulate.
freq_band : tuple of int or float
Frequency band where the connectivity should be simulated, where the first entry
corresponds to the lower frequency, and the second entry to the higher
frequency.
n_epochs : int (default 10)
Number of epochs in the simulated data.
n_times : int (default 200)
Number of timepoints each epoch of the simulated data.
sfreq : int | float (default 100)
Sampling frequency of the simulated data, in Hz.
trans_bandwidth : int | float (default 1)
Transition bandwidth of the filter to apply to isolate activity in
``freq_band``, in Hz.
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
snr : float (default 0.7)
Signal-to-noise ratio of the simulated data in the range [0, 1].
connection_delay : int (default 5)
Number of timepoints for the delay of connectivity between the seeds and
targets. If > 0, the target data is a delayed form of the seed data. If < 0, the
seed data is a delayed form of the target data.
tmin : int | float (default 0)
Earliest time of each epoch.
ch_names : list of str | None (default None)
Names of the channels in the simulated data. If `None`, the channels are named
according to their index and the frequency band of interaction.
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
ch_types : str | list of str (default "eeg")
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
Types of the channels in the simulated data.
rng_seed : int | None (default None)
Seed to use for the random number generator. If `None`, no seed is specified.

Returns
-------
epochs : mne.EpochsArray
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
The simulated data stored in an `mne.EpochsArray` object. The channels are
arranged according to seeds, then targets.

Notes
-----
Signals are simulated as a single source of activity in the given frequency band and
projected into ``n_seeds + n_targets`` noise channels.
"""
n_channels = n_seeds + n_targets

# check inputs
if n_seeds < 1 or n_targets < 1:
raise ValueError("Number of seeds and targets must each be at least 1.")

if not isinstance(freq_band, tuple):
raise TypeError("Frequency band must be a tuple.")
if len(freq_band) != 2:
raise ValueError("Frequency band must contain two numbers.")

if n_times < 1:
raise ValueError("Number of timepoints must be at least 1.")

if n_epochs < 1:
raise ValueError("Number of epochs must be at least 1.")

if sfreq <= 0:
raise ValueError("Sampling frequency must be > 0.")

if snr < 0 or snr > 1:
raise ValueError("Signal-to-noise ratio must be between 0 and 1.")

if np.abs(connection_delay) >= n_epochs * n_times:
raise ValueError(
"Connection delay must be less than the total number of timepoints."
)

# simulate data
rng = np.random.RandomState(rng_seed)
tsbinns marked this conversation as resolved.
Show resolved Hide resolved

# simulate signal source at desired frequency band
signal = rng.randn(1, n_epochs * n_times + np.abs(connection_delay))
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
signal = filter_data(
data=signal,
sfreq=sfreq,
l_freq=freq_band[0],
h_freq=freq_band[1],
l_trans_bandwidth=trans_bandwidth,
h_trans_bandwidth=trans_bandwidth,
fir_design="firwin2",
)

# simulate noise for each channel
noise = rng.randn(n_channels, n_epochs * n_times + np.abs(connection_delay))
tsbinns marked this conversation as resolved.
Show resolved Hide resolved

# create data by projecting signal into each channel of noise
data = (signal * snr) + (noise * (1 - snr))

# shift data by desired delay and remove extra time
if connection_delay != 0:
if connection_delay > 0:
delay_chans = np.arange(n_seeds, n_channels) # delay targets
else:
delay_chans = np.arange(0, n_seeds) # delay seeds
data[delay_chans, np.abs(connection_delay) :] = data[
delay_chans, : n_epochs * n_times
]
data = data[:, : n_epochs * n_times]

# reshape data into epochs
data = data.reshape(n_channels, n_epochs, n_times)
data = data.transpose((1, 0, 2)) # (epochs x channels x times)

# store data in an MNE EpochsArray object
if ch_names is None:
ch_names = [
f"{ch_i}_{freq_band[0]}_{freq_band[1]}" for ch_i in range(n_channels)
]
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
epochs = EpochsArray(data=data, info=info, tmin=tmin)

return epochs
169 changes: 169 additions & 0 deletions mne_connectivity/tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import numpy as np
import pytest

from mne_connectivity import (
make_signals_in_freq_bands,
seed_target_indices,
spectral_connectivity_epochs,
)


@pytest.mark.parametrize("n_seeds", [1, 3])
@pytest.mark.parametrize("n_targets", [1, 3])
@pytest.mark.parametrize("snr", [0.7, 0.4])
@pytest.mark.parametrize("connection_delay", [0, 3, -3])
@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"])
def test_make_signals_in_freq_bands(n_seeds, n_targets, snr, connection_delay, mode):
"""Test `make_signals_in_freq_bands` simulates connectivity properly."""
# Case with no spurious correlations (avoids tests randomly failing)
rng_seed = 0

# Simulate data
freq_band = (5, 10) # fmin, fmax (Hz)
sfreq = 100 # Hz
trans_bandwidth = 1 # Hz
data = make_signals_in_freq_bands(
n_seeds=n_seeds,
n_targets=n_targets,
freq_band=freq_band,
n_epochs=30,
n_times=200,
sfreq=sfreq,
trans_bandwidth=trans_bandwidth,
snr=snr,
connection_delay=connection_delay,
rng_seed=rng_seed,
)

# Compute connectivity
methods = ["coh", "imcoh", "dpli"]
indices = seed_target_indices(
seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds
)
fmin = 3
fmax = sfreq // 2
if mode == "cwt_morlet":
cwt_params = {"cwt_freqs": np.arange(fmin, fmax), "cwt_n_cycles": 3.5}
else:
cwt_params = dict()
con = spectral_connectivity_epochs(
data,
method=methods,
indices=indices,
mode=mode,
fmin=fmin,
fmax=fmax,
**cwt_params,
)
freqs = np.array(con[0].freqs)

# Define expected connectivity values
thresh_good = dict()
thresh_bad = dict()
# Coh
thresh_good["coh"] = (0.2, 0.9)
thresh_bad["coh"] = (0.0, 0.2)
# ImCoh
if connection_delay == 0:
thresh_good["imcoh"] = (0.0, 0.2)
thresh_bad["imcoh"] = (0.0, 0.2)
else:
thresh_good["imcoh"] = (0.2, 0.8)
thresh_bad["imcoh"] = (0.0, 0.2)
# DPLI
if connection_delay == 0:
thresh_good["dpli"] = (0.3, 0.6)
thresh_bad["dpli"] = (0.3, 0.6)
elif connection_delay > 0:
thresh_good["dpli"] = (0.5, 1)
thresh_bad["dpli"] = (0.3, 0.6)
else:
thresh_good["dpli"] = (0, 0.5)
thresh_bad["dpli"] = (0.3, 0.6)

# Check connectivity values are acceptable
freqs_good = np.argwhere(
(freqs >= freq_band[0]) & (freqs <= freq_band[1])
).flatten()
freqs_bad = np.argwhere(
(freqs < freq_band[0] - trans_bandwidth * 2)
| (freqs > freq_band[1] + trans_bandwidth * 2)
).flatten()
for method_name, method_con in zip(methods, con):
con_values = method_con.get_data()
if method_name == "imcoh":
con_values = np.abs(con_values)
# freq. band of interest
con_values_good = np.mean(con_values[:, freqs_good])
assert (
con_values_good >= thresh_good[method_name][0]
and con_values_good <= thresh_good[method_name][1]
)

# other freqs.
con_values_bad = np.mean(con_values[:, freqs_bad])
assert (
con_values_bad >= thresh_bad[method_name][0]
and con_values_bad <= thresh_bad[method_name][1]
)


def test_make_signals_error_catch():
"""Test error catching for `make_signals_in_freq_bands`."""
freq_band = (5, 10)

# check bad n_seeds/targets caught
with pytest.raises(
ValueError, match="Number of seeds and targets must each be at least 1."
):
make_signals_in_freq_bands(n_seeds=0, n_targets=1, freq_band=freq_band)
with pytest.raises(
ValueError, match="Number of seeds and targets must each be at least 1."
):
make_signals_in_freq_bands(n_seeds=1, n_targets=0, freq_band=freq_band)

# check bad freq_band caught
with pytest.raises(TypeError, match="Frequency band must be a tuple."):
make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=1)
with pytest.raises(ValueError, match="Frequency band must contain two numbers."):
make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=(1, 2, 3))

# check bad n_times
with pytest.raises(ValueError, match="Number of timepoints must be at least 1."):
make_signals_in_freq_bands(
n_seeds=1, n_targets=1, freq_band=freq_band, n_times=0
)

# check bad n_epochs
with pytest.raises(ValueError, match="Number of epochs must be at least 1."):
make_signals_in_freq_bands(
n_seeds=1, n_targets=1, freq_band=freq_band, n_epochs=0
)

# check bad sfreq
with pytest.raises(ValueError, match="Sampling frequency must be > 0."):
make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, sfreq=0)

# check bad snr
with pytest.raises(
ValueError, match="Signal-to-noise ratio must be between 0 and 1."
):
make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=-1)
with pytest.raises(
ValueError, match="Signal-to-noise ratio must be between 0 and 1."
):
make_signals_in_freq_bands(n_seeds=1, n_targets=1, freq_band=freq_band, snr=2)

# check bad connection_delay
with pytest.raises(
ValueError,
match="Connection delay must be less than the total number of timepoints.",
):
make_signals_in_freq_bands(
n_seeds=1,
n_targets=1,
freq_band=freq_band,
n_epochs=1,
n_times=1,
connection_delay=1,
)
Loading