Skip to content

Commit

Permalink
switched to masked indices for multivariate conn
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Nov 2, 2023
1 parent cd69d65 commit d6ba398
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 102 deletions.
39 changes: 18 additions & 21 deletions examples/handling_ragged_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,25 @@
# ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'),
# np.array([[2, 3, 4], [4 ]], dtype='object'))
#
# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be
# specified.**
# **N.B. Since NumPy v1.19.0, dtype='object' must be specified when forming
# ragged arrays.**
#
# Just as for bivariate connectivity, the length of ``indices[0]`` and
# ``indices[1]`` is equal (i.e. the number of connections), however information
# about the multiple channel indices for each connection is stored in a nested
# array. Importantly, these indices are ragged, as the first connection will be
# computed between 2 seed and 3 target channels, and the second connection
# between 4 seed and 1 target channel. The connectivity functions will
# recognise the indices as being ragged, and pad them accordingly to make them
# easier to work with and compatible with the h5netcdf saving engine. The known
# value used to pad the arrays is ``-1``, an invalid channel index. The above
# indices would be padded to::
# recognise the indices as being ragged, and pad them to a 'full' array by
# adding placeholder values which are masked accordingly. This makes the
# indices easier to work with, and also compatible with the engine used to save
# connectivity objects. For example, the above indices would become::
#
# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]),
# np.array([[2, 3, 4, -1], [4, -1, -1, -1]]))
# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]),
# np.array([[2, 3, 4, --], [4, --, --, --]]))
#
# These indices are what is stored in the connectivity object, and is also the
# format of indices returned from the helper functions
# :func:`~mne_connectivity.check_multivariate_indices` and
# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also
# possible to pass the padded indices to the connectivity functions directly.
# where ``--`` denotes a masked entry. These masked indices are what is stored
# in the returned connectivity objects.
#
# For the connectivity results themselves, the methods available in
# MNE-Connectivity combine information across the different channels into a
Expand Down Expand Up @@ -118,11 +115,11 @@
max_n_chans = max(
len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]]))

# show that the padded indices entries are all -1
assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels
assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels
assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels
assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels
# show that the padded indices entries are masked
assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels
assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels
assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels
assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels

# patterns have shape [seeds/targets x cons x max channels x freqs (x times)]
assert patterns.shape == (2, n_cons, max_n_chans, n_freqs)
Expand All @@ -137,11 +134,11 @@
seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])]
target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])]

# extract patterns for second connection using the padded indices (pad = -1)
# extract patterns for second connection using the padded, masked indices
seed_patterns_con2 = (
patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)])
patterns[0, 1, :padded_indices[0][1].count()])
target_patterns_con2 = (
patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)])
patterns[1, 1, :padded_indices[1][1].count()])

# show that shapes of patterns are correct
assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1)
Expand Down
6 changes: 4 additions & 2 deletions mne_connectivity/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def _xarray_to_conn(array, cls_func):
event_id = dict(zip(event_id_keys, event_id_vals))
array.attrs['event_id'] = event_id

# convert indices numpy arrays to a tuple of arrays
# convert indices numpy arrays to a tuple of masked arrays
# (only multivariate connectivity indices saved as arrays)
if isinstance(array.attrs['indices'], np.ndarray):
array.attrs['indices'] = tuple(array.attrs['indices'])
array.attrs['indices'] = tuple(
np.ma.masked_values(array.attrs['indices'], -1))

# create the connectivity class
conn = cls_func(
Expand Down
15 changes: 2 additions & 13 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,12 +1315,9 @@ def test_multivar_save_load(tmp_path):
assert a == b


@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv",
"mic", "mim"])
@pytest.mark.parametrize("method", ['coh', 'plv', 'pli', 'wpli', 'ciplv'])
@pytest.mark.parametrize("indices", [None,
(np.array([0, 1]), np.array([2, 3])),
(np.array([[0, 1]]), np.array([[2, 3]]))
])
(np.array([0, 1]), np.array([2, 3]))])
def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices):
"""Test that indices values and type is maintained after saving.
Expand All @@ -1337,14 +1334,6 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices):
freqs = np.arange(10, 31)
tmp_file = os.path.join(tmp_path, "foo_mvc.nc")

# multivariate and bivariate methods require the right indices shape
if method in ["mic", "mim"]:
if indices is not None and indices[0].ndim == 1:
pytest.skip()
else:
if indices is not None and indices[0].ndim == 2:
pytest.skip()

# test the pair of method and indices defined to check the output indices
con_epochs = spectral_connectivity_epochs(
epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30
Expand Down
15 changes: 7 additions & 8 deletions mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,20 +422,19 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False,
indices_use])
np.ma.set_fill_value(indices_use, -1) # else 99999 after concat.
if any(this_method in _gc_methods for this_method in method):
for seed, target in zip(indices[0], indices[1]):
intersection = np.intersect1d(seed, target)
if np.any(intersection != -1): # ignore padded entries
for seed, target in zip(indices_use[0], indices_use[1]):
intersection = np.intersect1d(seed.compressed(),
target.compressed())
if intersection.size > 0:
raise ValueError(
'seed and target indices must not intersect when '
'computing Granger causality')
# make sure padded indices are stored in the connectivity object
indices = tuple(np.array(indices_use)) # create a copy
# create a copy
indices = (indices_use[0].copy(), indices_use[1].copy())
else:
indices_use = check_indices(indices)
# create copies of indices_use for independent manipulation
source_idx = np.array(indices_use[0])
target_idx = np.array(indices_use[1])
n_cons = len(source_idx)
n_cons = len(indices_use[0])

# unique signals for which we actually need to compute the CSD of
if multivariate_con:
Expand Down
16 changes: 12 additions & 4 deletions mne_connectivity/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,22 @@ def test_seed_target_indices():
seeds = [[0, 1]]
targets = [[2, 3], [3, 4]]
indices = seed_target_multivariate_indices(seeds, targets)
assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]),
np.array([[2, 3], [3, 4]])))
match_indices = (np.array([[0, 1], [0, 1]], dtype=object),
np.array([[2, 3], [3, 4]], dtype=object))
for type_i in range(2):
for con_i in range(len(indices[0])):
assert np.all(indices[type_i][con_i] ==
match_indices[type_i][con_i])
# ragged indices
seeds = [[0, 1]]
targets = [[2, 3, 4], [4]]
indices = seed_target_multivariate_indices(seeds, targets)
assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]),
np.array([[2, 3, 4], [4, -1, -1]])))
match_indices = (np.array([[0, 1], [0, 1]], dtype=object),
np.array([[2, 3, 4], [4]], dtype=object))
for type_i in range(2):
for con_i in range(len(indices[0])):
assert np.all(indices[type_i][con_i] ==
match_indices[type_i][con_i])
# test error catching
# non-array-like seeds/targets
with pytest.raises(TypeError,
Expand Down
105 changes: 51 additions & 54 deletions mne_connectivity/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _check_multivariate_indices(indices, n_chans):
Returns
-------
indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans)
The indices padded with the invalid channel index ``-1``.
The indices as a tuple of two numpy masked arrays (seeds and targets), in
which the padded entries are masked.
Notes
-----
Expand All @@ -112,26 +112,35 @@ def _check_multivariate_indices(indices, n_chans):
connection must be unique.
If the seed and target indices are given as lists or tuples, they will be
converted to numpy arrays. In case the number of channels differs across
converted to numpy arrays. Because the number of channels can differ across
connections or between the seeds and targets for a given connection (i.e.
ragged indices), the returned array will be padded with the invalid channel
index ``-1`` according to the maximum number of channels in the seed or
target of any one connection. E.g. the ragged indices of shape ``(2,
n_cons, variable)``::
ragged/jagged indices), the returned array will be padded out to a 'full'
array with an invalid index (``-1``) according to the maximum number of
channels in the seed or target of any one connection. These invalid
entries are then masked and returned as numpy masked arrays. E.g. the
ragged indices of shape ``(2, n_cons, variable)``::
indices = ([[0, 1], [0, 1 ]], # seeds
[[2, 3], [4, 5, 6]]) # targets
would be returned as::
would be padded to full arrays::
indices = ([[0, 1, -1], [0, 1, -1]], # seeds
[[2, 3, -1], [4, 5, 6]]) # targets
to have shape ``(2, n_cons, max_n_chans)``, where ``max_n_chans = 3``. The
invalid entries are then masked::
indices = (np.array([[0, 1, -1], [0, 1, -1]]), # seeds
np.array([[2, 3, -1], [4, 5, -1]])) # targets
indices = ([[0, 1, --], [0, 1, --]], # seeds
[[2, 3, --], [4, 5, 6]]) # targets
where the indices have been padded with ``-1`` to have shape ``(2, n_cons,
max_n_chans)``, where ``max_n_chans = 3``. More information on working with
multivariate indices and handling connections where the number of seeds and
targets are not equal can be found in the
:doc:`../auto_examples/handling_ragged_arrays` example.
In case "indices" contains negative values to index channels, these will be
converted to the corresponding positive-valued index before any masking is
applied.
More information on working with multivariate indices and handling
connections where the number of seeds and targets are not equal can be
found in the :doc:`../auto_examples/handling_ragged_arrays` example.
"""
if not isinstance(indices, tuple) or len(indices) != 2:
raise ValueError('indices must be a tuple of length 2')
Expand All @@ -141,6 +150,7 @@ def _check_multivariate_indices(indices, n_chans):
'have the same length')

n_cons = len(indices[0])
invalid = -1

max_n_chans = 0
for group_idx, group in enumerate(indices):
Expand All @@ -166,15 +176,19 @@ def _check_multivariate_indices(indices, n_chans):
indices[group_idx][con_idx][chan_idx] = chan % n_chans

# pad indices to avoid ragged arrays
padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32),
np.full((n_cons, max_n_chans), -1, dtype=np.int32))
padded_indices = (np.full((n_cons, max_n_chans), invalid, dtype=np.int32),
np.full((n_cons, max_n_chans), invalid, dtype=np.int32))
con_i = 0
for seed, target in zip(indices[0], indices[1]):
padded_indices[0][con_i, :len(seed)] = seed
padded_indices[1][con_i, :len(target)] = target
con_i += 1

return padded_indices
# mask invalid indices
masked_indices = (np.ma.masked_values(padded_indices[0], invalid),
np.ma.masked_values(padded_indices[1], invalid))

return masked_indices


def seed_target_indices(seeds, targets):
Expand Down Expand Up @@ -236,8 +250,8 @@ def seed_target_multivariate_indices(seeds, targets):
Returns
-------
indices : tuple of array of array of int, shape (2, n_cons, max_n_chans)
The indices padded with the invalid channel index ``-1``.
indices : tuple of array of array of int, shape (2, n_cons, variable)
The indices as a numpy object array.
Notes
-----
Expand All @@ -247,28 +261,24 @@ def seed_target_multivariate_indices(seeds, targets):
channels in the data. The number of indices for each connection does not need
to be equal. Furthermore, all indices within a connection must be unique.
``seeds`` and ``targets`` will be expanded such that connectivity will be
computed between each set of seeds and targets. In case the number of
channels differs across connections or between the seeds and targets for a
given connection (i.e. ragged indices), the returned array will be padded
with the invalid channel index ``-1`` according to the maximum number of
channels in the seed or target of any one connection. E.g. ``seeds`` and
Because the number of channels per connection can vary, the indices are
stored as numpy arrays with ``dtype=object``. E.g. ``seeds`` and
``targets``::
seeds = [[0]]
targets = [[1, 2], [3, 4, 5]]
would be returned as::
indices = (np.array([[0, -1, -1], [0, -1, -1]]), # seeds
np.array([[1, 2, -1], [3, 4, 5]])) # targets
indices = (np.array([[0 ], [0 ]], dtype=object), # seeds
np.array([[1, 2], [3, 4, 5]], dtype=object)) # targets
Even if the number of channels does not vary, the indices will still be
stored as object arrays for compatibility.
where the indices have been padded with ``-1`` to have shape ``(2, n_cons,
max_n_chans)``, where ``n_cons = n_unique_seeds * n_unique_targets`` and
``max_n_chans = 3``. More information on working with multivariate indices
and handling connections where the number of seeds and targets are not
equal can be found in the :doc:`../auto_examples/handling_ragged_arrays`
example.
More information on working with multivariate indices and handling
connections where the number of seeds and targets are not equal can be
found in the :doc:`../auto_examples/handling_ragged_arrays` example.
"""
array_like = (np.ndarray, list, tuple)

Expand All @@ -278,35 +288,22 @@ def seed_target_multivariate_indices(seeds, targets):
):
raise TypeError('`seeds` and `targets` must be array-like')

n_chans = []
for inds in [*seeds, *targets]:
if not isinstance(inds, array_like):
raise TypeError(
'`seeds` and `targets` must contain nested array-likes')
if len(inds) != len(np.unique(inds)):
raise ValueError(
'`seeds` and `targets` cannot contain repeated channels')
n_chans.append(len(inds))
max_n_chans = max(n_chans)
n_cons = len(seeds) * len(targets)

# pad indices to avoid ragged arrays
padded_seeds = np.full((len(seeds), max_n_chans), -1, dtype=np.int32)
padded_targets = np.full((len(targets), max_n_chans), -1, dtype=np.int32)
for con_i, seed in enumerate(seeds):
padded_seeds[con_i, :len(seed)] = seed
for con_i, target in enumerate(targets):
padded_targets[con_i, :len(target)] = target

# create final indices
indices = (np.zeros((n_cons, max_n_chans), dtype=np.int32),
np.zeros((n_cons, max_n_chans), dtype=np.int32))
con_i = 0
for seed in padded_seeds:
for target in padded_targets:
indices[0][con_i] = seed
indices[1][con_i] = target
con_i += 1
indices = [[], []]
for seed in seeds:
for target in targets:
indices[0].append(np.array(seed))
indices[1].append(np.array(target))

indices = (np.array(indices[0], dtype=object),
np.array(indices[1], dtype=object))

return indices

Expand Down

0 comments on commit d6ba398

Please sign in to comment.