Skip to content

Commit

Permalink
Merge pull request mne-tools#142 from tsbinns/pr-mvc_padding
Browse files Browse the repository at this point in the history
[ENH] Add support for ragged connections with multivariate methods with padding
  • Loading branch information
drammock authored Nov 3, 2023
2 parents c33d491 + 3e81252 commit 8bebc62
Show file tree
Hide file tree
Showing 15 changed files with 1,102 additions and 320 deletions.
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Post-processing on connectivity

degree
seed_target_indices
seed_target_multivariate_indices
check_indices
select_order

Expand Down
3 changes: 2 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@
'n_node_names', 'n_tapers', 'n_signals', 'n_step', 'n_freqs',
'epochs', 'freqs', 'times', 'arrays', 'lists', 'func', 'n_nodes',
'n_estimated_nodes', 'n_samples', 'n_channels', 'Renderer',
'n_ytimes', 'n_ychannels', 'n_events'
'n_ytimes', 'n_ychannels', 'n_events', 'n_cons', 'max_n_chans',
'n_unique_seeds', 'n_unique_targets', 'variable'
}
numpydoc_xref_aliases = {
# Python
Expand Down
46 changes: 18 additions & 28 deletions examples/granger_causality.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@

###############################################################################
# We will focus on connectivity between sensors over the parietal and occipital
# cortices, with 20 parietal sensors designated as group A, and 20 occipital
# cortices, with 20 parietal sensors designated as group A, and 22 occipital
# sensors designated as group B.

# %%
Expand All @@ -157,17 +157,8 @@
signals_b = [idx for idx, ch_info in enumerate(epochs.info['chs']) if
ch_info['ch_name'][2] == 'O']

# XXX: Currently ragged indices are not supported, so we only consider a single
# list of indices with an equal number of seeds and targets
min_n_chs = min(len(signals_a), len(signals_b))
signals_a = signals_a[:min_n_chs]
signals_b = signals_b[:min_n_chs]

indices_ab = (np.array(signals_a), np.array(signals_b)) # A => B
indices_ba = (np.array(signals_b), np.array(signals_a)) # B => A

signals_a_names = [epochs.info['ch_names'][idx] for idx in signals_a]
signals_b_names = [epochs.info['ch_names'][idx] for idx in signals_b]
indices_ab = (np.array([signals_a]), np.array([signals_b])) # A => B
indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A

# compute Granger causality
gc_ab = spectral_connectivity_epochs(
Expand All @@ -181,8 +172,8 @@

###############################################################################
# Plotting the results, we see that there is a flow of information from our
# parietal sensors (group A) to our occipital sensors (group B) with noticeable
# peaks at around 8, 18, and 26 Hz.
# parietal sensors (group A) to our occipital sensors (group B) with a
# noticeable peak at ~8 Hz, and smaller peaks at 18 and 26 Hz.

# %%

Expand All @@ -208,8 +199,7 @@
#
# Doing so, we see that the flow of information across the spectrum remains
# dominant from parietal to occipital sensors (indicated by the positive-valued
# Granger scores). However, the pattern of connectivity is altered, such as
# around 10 and 12 Hz where peaks of net information flow are now present.
# Granger scores), with similar peaks around 10, 18, and 26 Hz.

# %%

Expand Down Expand Up @@ -289,8 +279,8 @@
# Plotting the TRGC results, reveals a very different picture compared to net
# GC. For one, there is now a dominance of information flow ~6 Hz from
# occipital to parietal sensors (indicated by the negative-valued Granger
# scores). Additionally, the peaks ~10 Hz are less dominant in the spectrum,
# with parietal to occipital information flow between 13-20 Hz being much more
# scores). Additionally, the peak ~10 Hz is less dominant in the spectrum, with
# parietal to occipital information flow between 13-20 Hz being much more
# prominent. The stark difference between net GC and TRGC results indicates
# that the net GC spectrum was contaminated by spurious connectivity resulting
# from source mixing or correlated noise in the recordings. Altogether, the use
Expand Down Expand Up @@ -353,21 +343,21 @@
# an automatic rank computation is performed and an appropriate degree of
# dimensionality reduction will be enforced. The rank of the data is determined
# by computing the singular values of the data and finding those within a
# factor of :math:`1e^{-10}` relative to the largest singular value.
# factor of :math:`1e^{-6}` relative to the largest singular value.
#
# In some circumstances, this threshold may be too lenient, in which case you
# should inspect the singular values of your data to identify an appropriate
# degree of dimensionality reduction to perform, which you can then specify
# manually using the ``rank`` argument. The code below shows one possible
# approach for finding an appropriate rank of close-to-singular data with a
# more conservative threshold of :math:`1e^{-5}`.
# Whilst unlikely, there may be scenarios in which this threshold may be too
# lenient. In these cases, you should inspect the singular values of your data
# to identify an appropriate degree of dimensionality reduction to perform,
# which you can then specify manually using the ``rank`` argument. The code
# below shows one possible approach for finding an appropriate rank of
# close-to-singular data with a more conservative threshold.

# %%

# gets the singular values of the data
s = np.linalg.svd(raw.get_data(), compute_uv=False)
# finds how many singular values are "close" to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria
# finds how many singular values are 'close' to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-4) # 1e-4 is the 'closeness' criteria

###############################################################################
# Nonethless, even in situations where you specify an appropriate rank, it is
Expand All @@ -387,7 +377,7 @@
try:
spectral_connectivity_epochs(
epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None,
gc_n_lags=20) # A => B
gc_n_lags=20, verbose=False) # A => B
print('Success!')
except RuntimeError as error:
print('\nCaught the following error:\n' + repr(error))
Expand Down
151 changes: 151 additions & 0 deletions examples/handling_ragged_arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""
=========================================================
Working with ragged indices for multivariate connectivity
=========================================================
This example demonstrates how multivariate connectivity involving different
numbers of seeds and targets can be handled in MNE-Connectivity.
"""

# Author: Thomas S. Binns <[email protected]>
# License: BSD (3-clause)

# %%

import numpy as np

from mne_connectivity import spectral_connectivity_epochs

###############################################################################
# Background
# ----------
#
# With multivariate connectivity, interactions between multiple signals can be
# considered together, and the number of signals designated as seeds and
# targets does not have to be equal within or across connections. Issues can
# arise from this when storing information associated with connectivity in
# arrays, as the number of entries within each dimension can vary within and
# across connections depending on the number of seeds and targets. Such arrays
# are 'ragged', and support for ragged arrays is limited in NumPy to the
# ``object`` datatype. Not only is working with ragged arrays is cumbersome,
# but saving arrays with ``dtype='object'`` is not supported by the h5netcdf
# engine used to save connectivity objects. The workaround used in
# MNE-Connectivity is to pad ragged arrays with some known values according to
# the largest number of entries in each dimension, such that there is an equal
# amount of information across and within connections for each dimension of the
# arrays.
#
# As an example, consider we have 5 channels and want to compute 2 connections:
# the first between channels in indices 0 and 1 with those in indices 2, 3,
# and 4; and the second between channels 0, 1, 2, and 3 with channel 4. The
# seed and target indices can be written as such::
#
# seeds = [[0, 1 ], [0, 1, 2, 3]]
# targets = [[2, 3, 4], [4 ]]
#
# The ``indices`` parameter passed to
# :func:`~mne_connectivity.spectral_connectivity_epochs` and
# :func:`~mne_connectivity.spectral_connectivity_time` must be a tuple of
# array-likes, meaning
# that the indices can be passed as a tuple of: lists; tuples; or NumPy arrays.
# Examples of how ``indices`` can be formed are shown below::
#
# # tuple of lists
# ragged_indices = ([[0, 1 ], [0, 1, 2, 3]],
# [[2, 3, 4], [4 ]])
#
# # tuple of tuples
# ragged_indices = (((0, 1 ), (0, 1, 2, 3)),
# ((2, 3, 4), (4 )))
#
# # tuple of arrays
# ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'),
# np.array([[2, 3, 4], [4 ]], dtype='object'))
#
# **N.B. Note that since NumPy v1.19.0, dtype='object' must be specified when
# forming ragged arrays.**
#
# Just as for bivariate connectivity, the length of ``indices[0]`` and
# ``indices[1]`` is equal (i.e. the number of connections), however information
# about the multiple channel indices for each connection is stored in a nested
# array. Importantly, these indices are ragged, as the first connection will be
# computed between 2 seed and 3 target channels, and the second connection
# between 4 seed and 1 target channel(s). The connectivity functions will
# recognise the indices as being ragged, and pad them to a 'full' array by
# adding placeholder values which are masked accordingly. This makes the
# indices easier to work with, and also compatible with the engine used to save
# connectivity objects. For example, the above indices would become::
#
# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]),
# np.array([[2, 3, 4, --], [4, --, --, --]]))
#
# where ``--`` are masked entries. These indices are what is stored in the
# returned connectivity objects.
#
# For the connectivity results themselves, the methods available in
# MNE-Connectivity combine information across the different channels into a
# single (time-)frequency-resolved connectivity spectrum, regardless of the
# number of seed and target channels, so ragged arrays are not a concern here.
# However, the maximised imaginary part of coherency (MIC) method also returns
# spatial patterns of connectivity, which show the contribution of each channel
# to the dimensionality-reduced connectivity estimate (explained in more detail
# in :doc:`mic_mim`). Because these patterns are returned for each channel,
# their shape can vary depending on the number of seeds and targets in each
# connection, making them ragged. To avoid this, the patterns are padded along
# the channel axis with the known and invalid entry ``np.nan``, in line with
# that applied to ``indices``. Extracting only the valid spatial patterns from
# the connectivity object is trivial, as shown below:

# %%

# create random data
data = np.random.randn(10, 5, 200) # epochs x channels x times
sfreq = 50
ragged_indices = ([[0, 1], [0, 1, 2, 3]], # seeds
[[2, 3, 4], [4]]) # targets

# compute connectivity
con = spectral_connectivity_epochs(
data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30,
verbose=False)
patterns = np.array(con.attrs['patterns'])
padded_indices = con.indices
n_freqs = con.get_data().shape[-1]
n_cons = len(ragged_indices[0])
max_n_chans = max(
len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]]))

# show that the padded indices entries are masked
assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels
assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels
assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels
assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels

# patterns have shape [seeds/targets x cons x max channels x freqs (x times)]
assert patterns.shape == (2, n_cons, max_n_chans, n_freqs)

# show that the padded patterns entries are all np.nan
assert np.all(np.isnan(patterns[0, 0, 2:])) # 2 padded channels
assert np.all(np.isnan(patterns[1, 0, 3:])) # 1 padded channels
assert not np.any(np.isnan(patterns[0, 1])) # 0 padded channels
assert np.all(np.isnan(patterns[1, 1, 1:])) # 3 padded channels

# extract patterns for first connection using the ragged indices
seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])]
target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])]

# extract patterns for second connection using the padded, masked indices
seed_patterns_con2 = (
patterns[0, 1, :padded_indices[0][1].count()])
target_patterns_con2 = (
patterns[1, 1, :padded_indices[1][1].count()])

# show that shapes of patterns are correct
assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1)
assert target_patterns_con1.shape == (3, n_freqs) # channels (2, 3, 4)
assert seed_patterns_con2.shape == (4, n_freqs) # channels (0, 1, 2, 3)
assert target_patterns_con2.shape == (1, n_freqs) # channels (4)

print('Assertions completed successfully!')

# %%
Loading

0 comments on commit 8bebc62

Please sign in to comment.