[ENH] Add support for ragged connections with multivariate methods with padding #142
mne_connectivity/utils/utils.py
Returns
-------
indices : tuple of array of array of int
    The indices padded with the invalid channel index ``-1``.
Is -1 ever going to be used?
E.g. -1 usually stands for the last position when indexing in Python. But in our context, is this fine? Or is np.nan better?
I'm just wondering. If I'm wrong, then that's fine.
I agree that -1 is technically a valid index and np.nan would be best, however np.nan is a float and cannot be represented in an array of ints.
If we switched to an array of floats, np.nan could be represented, however numpy does not like floats as indices, so we would have to convert to ints every time we want to index channels.
One alternative is numpy masked arrays (https://numpy.org/doc/stable/reference/maskedarray.html) however I'm not sure how this would be handled when saving using h5netcdf. Perhaps another option could be to choose a very unlikely integer like -99,999.
Ultimately I would be happy to go with whatever you feel is most appropriate.
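The trade-offs discussed here can be checked directly in NumPy. The following is a minimal sketch, not code from the PR: the array values are made up for illustration. It only shows why `np.nan` cannot live in an integer array, why a float array cannot be used for indexing, and how a masked array keeps an integer dtype while hiding a padding value.

```python
import numpy as np

data = np.arange(10) * 10  # stand-in for channel data (illustrative only)

# 1) np.nan cannot be stored in an integer array
int_indices = np.array([0, 1, 2])
try:
    int_indices[0] = np.nan
    nan_stored = True
except ValueError:
    nan_stored = False

# 2) a float array can hold np.nan, but NumPy refuses to index with floats
float_indices = np.array([0.0, 1.0, np.nan])
try:
    data[float_indices]
    float_indexing_ok = True
except IndexError:
    float_indexing_ok = False

# 3) a masked array keeps the integer dtype and hides the padding value
padded = np.array([[0, 1, -1], [2, 3, 4]])
masked = np.ma.masked_equal(padded, -1)
valid_first = masked[0].compressed()  # only the unmasked entries of row 0
first_values = data[valid_first]      # safe to use for indexing
```

With a masked array the padding value is still stored in the underlying data, but it can no longer be picked up by accident as "the last channel" as long as `compressed()` (or the mask) is used before indexing.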
I'll defer to you. I think it's something we can consider changing in the future if it's an issue for users.
I also find the padding with -1 to be unexpected. Without knowing the ramifications offhand, I would think that a masked array, a scipy sparse array, or possibly the pandas nullable integer dtype would be preferable. @tsbinns mentioned saving to netcdf as a possible constraint/limitation... will the tuple of seed/target indices themselves need to be saved out in that format? Or just the resultant connectivity objects (which presumably keeps the indices in a class attribute)? Asking because how we represent the indices could conceivably be different between what is given to the user by this function & passed by the user into the connectivity funcs, versus how they are stored internally inside the resulting connectivity objects.
I will try to review this in-depth next week
@adam2392 I updated the documentation for the util functions:
indices = (np.array([[0, 1]]), np.array([[2]]))
conn_mvc = spectral_connectivity_epochs(
    epochs, method="mic", indices=indices, sfreq=sfreq, fmin=10, fmax=40)
conn_mvc.save(tmp_path / 'foo_mvc.nc')
Should we move this to a separate test and also test roundtrip IO? I.e. reading in the file after saving results in everything being the same as what you would expect?
Yes, good idea. I have added a new test for saving the multivariate results and checking their integrity:
mne-connectivity/mne_connectivity/spectral/tests/test_spectral.py
Lines 1286 to 1315 in 409c2c6
def test_multivar_save_load(tmp_path):
    """Test saving and loading results of multivariate connectivity."""
    rng = np.random.RandomState(0)
    n_epochs, n_chs, n_times, sfreq, f = 10, 4, 2000, 1000., 20.
    data = rng.randn(n_epochs, n_chs, n_times)
    sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000)
    data[:, :, 500:1500] += sig
    info = create_info(n_chs, sfreq, 'eeg')
    tmin = -1
    epochs = EpochsArray(data, info, tmin=tmin)
    tmp_file = os.path.join(tmp_path, 'foo_mvc.nc')

    non_ragged_indices = (np.array([[0, 1]]), np.array([[2, 3]]))
    ragged_indices = (np.array([[0, 1]]), np.array([[2]]))
    for indices in [non_ragged_indices, ragged_indices]:
        con = spectral_connectivity_epochs(
            epochs, method=['mic', 'mim', 'gc', 'gc_tr'], indices=indices,
            sfreq=sfreq, fmin=10, fmax=30)
        for this_con in con:
            this_con.save(tmp_file)
            read_con = read_connectivity(tmp_file)
            assert_array_almost_equal(this_con.get_data(),
                                      read_con.get_data('raveled'))
            if this_con.attrs['patterns'] is not None:
                assert_array_almost_equal(np.array(this_con.attrs['patterns']),
                                          np.array(read_con.attrs['patterns']))
            # split `repr` before the file size (`~23 kB` for example)
            a = repr(this_con).split('~')[0]
            b = repr(read_con).split('~')[0]
            assert a == b
When adding this test, I realised that read_connectivity returns an object whose indices attribute is an array, whereas it is a tuple of arrays/lists in the original object.
Because indices is an array, a later check in _check_data_consistency that indices is a tuple fails (L538), and so it tries to check if the array is a string (L553), which leads to an elementwise comparison that fails. With my NumPy version (1.24.3) I get the following warning which causes pytest to fail: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison, so this should cause a straight-up error in newer NumPy versions.
mne-connectivity/mne_connectivity/base.py
Lines 538 to 561 in 409c2c6
if isinstance(indices, tuple):
    # check that the indices passed in are of the same length
    if len(indices[0]) != len(indices[1]):
        raise ValueError(f'If indices are passed in '
                         f'then they must be the same '
                         f'length. They are right now '
                         f'{len(indices[0])} and '
                         f'{len(indices[1])}.')
    # indices length should match the data length
    if len(indices[0]) != data_len:
        raise ValueError(
            f'The number of indices, {len(indices[0])} '
            f'should match the raveled data length passed '
            f'in of {data_len}.')

elif indices == 'symmetric':
    expected_len = ((n_nodes + 1) * n_nodes) // 2
    if data_len != expected_len:
        raise ValueError(f'If "indices" is "symmetric", then '
                         f'connectivity data should be the '
                         f'upper-triangular part of the matrix. There '
                         f'are {data_len} estimated connections. '
                         f'But there should be {expected_len} '
                         f'estimated connections.')
The same issue also occurs for the pre-existing bivariate connectivity measures, however since reading saved connectivity objects is only tested where indices=None, this warning is never seen.
A solution is to convert indices from an array to a tuple of arrays/lists when reading connectivity objects in _xarray_to_conn (i.e. return it to the same format as it was in the original connectivity object).
mne-connectivity/mne_connectivity/io.py
Lines 56 to 58 in 409c2c6
# convert indices numpy arrays to a tuple of arrays
if isinstance(array.attrs['indices'], np.ndarray):
    array.attrs['indices'] = tuple(array.attrs['indices'])
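To illustrate what this conversion does in isolation, here is a small sketch that does not touch any file I/O. It assumes, for illustration only, that the round trip collapses the tuple of seed/target arrays into a single 2D array, which is mimicked with `np.array(...)`; the variable names are hypothetical.

```python
import numpy as np

# indices as held by the original connectivity object: a tuple of arrays
original = (np.array([0, 1]), np.array([2, 3]))

# assumption: after a save/load round trip the tuple comes back as one array
loaded = np.array(original)
is_tuple_before = isinstance(loaded, tuple)

# iterating over the first axis restores a tuple of (seeds, targets) arrays
restored = tuple(loaded)
is_tuple_after = isinstance(restored, tuple)
```

After the conversion, an `isinstance(indices, tuple)` check passes again, so the comparison against `'symmetric'` is never reached for array-valued indices.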
I'm inclined towards the soln. that changes existing code as little as possible. I think the soln. you proposed of converting indices from an array to a tuple of arrays/lists sounds good.
Can you also add a separate unit-test for this when indices != None for existing methods and the new multivariate methods to demonstrate how this works?
Thanks for investigating!
Added a test for checking that the type and values of indices are retained after saving for: bivariate and multivariate connectivity; with spectral_connectivity_epochs and spectral_connectivity_time; with indices = None and indices != None.
mne-connectivity/mne_connectivity/spectral/tests/test_spectral.py
Lines 1318 to 1360 in b7fcf12
def test_spectral_connectivity_indices_maintained(tmp_path):
    """Test that indices values and type is maintained after saving.

    If `indices` is None, `indices` in the returned connectivity object should
    be None, otherwise, `indices` should be a tuple. The type of `indices` and
    its values should be retained after saving and reloading.
    """
    rng = np.random.RandomState(0)
    n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000., 20.
    data = rng.randn(n_epochs, n_chs, n_times)
    sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000)
    data[:, :, 500:1500] += sig
    info = create_info(n_chs, sfreq, 'eeg')
    tmin = -1
    epochs = EpochsArray(data, info, tmin=tmin)
    freqs = np.arange(10, 31)
    tmp_file = os.path.join(tmp_path, 'foo_mvc.nc')

    bivar_indices = (np.array([0, 1]), np.array([2, 3]))
    multivar_indices = (np.array([[0, 1]]), np.array([[2, 3]]))
    indices = [None, bivar_indices, None, multivar_indices]
    methods = ['coh', 'coh', 'mic', 'mic']

    for this_indices, this_method in zip(indices, methods):
        con_epochs = spectral_connectivity_epochs(
            epochs, method=this_method, indices=this_indices, sfreq=sfreq,
            fmin=10, fmax=30)
        con_time = spectral_connectivity_time(
            epochs, freqs, method=this_method, indices=this_indices,
            sfreq=sfreq)

        for con in [con_epochs, con_time]:
            con.save(tmp_file)
            read_con = read_connectivity(tmp_file)

            if this_indices is not None:
                # check indices of same type (tuples)
                assert (isinstance(con.indices, tuple) and
                        isinstance(read_con.indices, tuple))
                # check indices have same values
                assert np.all(np.array(con.indices) ==
                              np.array(read_con.indices))
            else:
                assert con.indices is None and read_con.indices is None
I left a few questions on areas that seemed like potential code issues. Overall, I don't see any major issues and think this should be mergable soon.
FYI, one thing that makes the reviewing process easier in general is limiting the code diff to code that is relevant to this change.
Also, after this is merged, as we discussed in #125 are you interested in joining the maintenance team and getting core-dev rights to mne-connectivity?
@adam2392 Thanks very much for the detailed feedback; I have made some updates accordingly.
Yes please, I am very happy to help out and make sure this can be maintained!
Signed-off-by: Adam Li <[email protected]>
I think this mostly LGTM, but I have a question regarding indices that I didn't spot earlier. The unit-test you added in
However, this test fails.
Discussion on why unit-test fails
We have indices defined in two ways, which to me is not easy to remember:
I would expect that if someone forgets to pass in a 2D array, they accidentally perform bivariate indices. I'm just wondering if it's possible to infer this behavior in any way? If not, I think the docstring can maybe be a bit more clear on what indices to expect. I.e. perhaps a
So IIUC, mic/mim will not work when bivariate indices are passed. And bivariate methods will not work when multivariate indices are passed. However, the current error messages are rather uninformative for bivariate methods:
The multivariate one looks good:
Lmk if there's anything I can clarify
Sorry that the commit history is such a mess, but I have reverted all the refactoring changes. Everything now relates only to the support for multiple connections with the multivariate methods, with the use of masked arrays for handling ragged indices. I have also made
mne-connectivity/mne_connectivity/utils/utils.py
Lines 86 to 87 in d6ba398
I believe all points relating to how multivariate indices are handled have now been addressed.
  try:
      spectral_connectivity_epochs(
          epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None,
-         gc_n_lags=20)  # A => B
+         gc_n_lags=20, verbose=False)  # A => B
      print('Success!')
  except RuntimeError as error:
      print('\nCaught the following error:\n' + repr(error))
If you want, you could add this to expected_failing_examples, then you don't have to try/except at all here. Just let it actually fail and sphinx-gallery will print a nicely formatted traceback for you.
https://sphinx-gallery.github.io/stable/configuration.html#dont-fail-exit
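As a rough sketch of what this suggestion could look like in a Sphinx `conf.py`: `expected_failing_examples`, `examples_dirs` and `gallery_dirs` are Sphinx-Gallery configuration keys, but the example file name and directory layout below are placeholders chosen for illustration, not taken from the MNE-Connectivity repository.

```python
import os

# hypothetical path to the example that is allowed to raise an error
failing_example = os.path.join("..", "examples", "my_failing_example.py")

sphinx_gallery_conf = {
    "examples_dirs": ["../examples"],    # where the example scripts live (assumed)
    "gallery_dirs": ["auto_examples"],   # where the rendered gallery is written (assumed)
    # scripts listed here may fail without breaking the documentation build
    "expected_failing_examples": [failing_example],
}
```

With such an entry, the script would raise its error uncaught and the traceback would be rendered in the gallery page in place of the manual try/except.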
Hmm not at the moment. Feel free to open an issue at https://github.com/sphinx-gallery/sphinx-gallery about adding this possibility
Opened! sphinx-gallery/sphinx-gallery#1220
If the others agree, I would stick with the try-except approach until this behaviour in sphinx gallery is changed, so as not to give the impression from the thumbnail that the entire example is about failing code.
I think this is fine as is already, no need to wait to merge for this. It's an easy follow up PR after this is merged and SG has the machinery it needs
    connectivity. If a bivariate method is called, each array for the seeds
    and targets should contain the channel indices for each bivariate
    connection. If a multivariate method is called, each array for the
    seeds and targets should consist of nested arrays containing
    the channel indices for each multivariate connection. If ``None``,
    connections between all channels are computed, unless a Granger
    causality method is called, in which case an error is raised.
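The difference between the two formats described in this docstring can be made concrete with a short sketch. The index values follow the tests quoted earlier in this thread; `is_nested` is a hypothetical helper written only for this illustration and is not part of MNE-Connectivity.

```python
import numpy as np

# bivariate: one seed channel and one target channel per connection
bivar_indices = (np.array([0, 1]), np.array([2, 3]))        # 2 connections

# multivariate: one *set* of seed/target channels per connection
multivar_indices = (np.array([[0, 1]]), np.array([[2, 3]]))  # 1 connection

# ragged multivariate indices are given as nested lists
ragged_indices = ([[0, 1], [0, 1]], [[2, 3], [4, 5, 6]])     # 2 connections


def is_nested(indices):
    """Return True if every connection entry is itself a 1D set of channels."""
    return all(np.ndim(con) == 1 for group in indices for con in group)


n_bivar_cons = len(bivar_indices[0])
n_multivar_cons = len(multivar_indices[0])
```

Here `is_nested` is False for the bivariate tuple and True for both multivariate tuples, which is the kind of check that could back a more informative error message when the wrong format is passed.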
Do we need to discuss anything about the masking behavior? Or would we say that's an internal implementation detail?
If it's an internal implementation detail, I wonder if we can perhaps add inline comments where it first comes up about how the indices variable changes into a mask.
I'm okay being vetoed here. Just a thought
In terms of what is passed into the function, the masking is an internal feature. At present there is only a small inline comment documenting this:
mne-connectivity/mne_connectivity/spectral/epochs.py
Lines 118 to 119 in 883e816
# mask indices
indices_use = _check_multivariate_indices(indices, n_signals)
... although the function where the masking happens is documented:
mne-connectivity/mne_connectivity/utils/utils.py
Lines 86 to 144 in 883e816
def _check_multivariate_indices(indices, n_chans):
    """Check indices parameter for multivariate connectivity and mask it.

    Parameters
    ----------
    indices : tuple of array of array of int, shape (2, n_cons, variable)
        Tuple containing index sets.
    n_chans : int
        The number of channels in the data. Used when converting negative
        indices to positive indices.

    Returns
    -------
    indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans)
        The indices as a masked array.

    Notes
    -----
    Indices for multivariate connectivity should be a tuple of length 2
    containing the channel indices for the seed and target channel sets,
    respectively. Seed and target indices should be equal-length array-likes
    representing the indices of the channel sets in the data for each
    connection. The indices for each connection should be an array-like of
    integers representing the individual channels in the data. The length of
    indices for each connection do not need to be equal. All indices within a
    connection must be unique.

    If the seed and target indices are given as lists or tuples, they will be
    converted to numpy arrays. Because the number of channels can differ across
    connections or between the seeds and targets for a given connection (i.e.
    ragged/jagged indices), the returned array will be padded out to a 'full'
    array with an invalid index (``-1``) according to the maximum number of
    channels in the seed or target of any one connection. These invalid
    entries are then masked and returned as numpy masked arrays. E.g. the
    ragged indices of shape ``(2, n_cons, variable)``::

        indices = ([[0, 1], [0, 1   ]],  # seeds
                   [[2, 3], [4, 5, 6]])  # targets

    would be padded to full arrays::

        indices = ([[0, 1, -1], [0, 1, -1]],  # seeds
                   [[2, 3, -1], [4, 5,  6]])  # targets

    to have shape ``(2, n_cons, max_n_chans)``, where ``max_n_chans = 3``. The
    invalid entries are then masked::

        indices = ([[0, 1, --], [0, 1, --]],  # seeds
                   [[2, 3, --], [4, 5,  6]])  # targets

    In case "indices" contains negative values to index channels, these will be
    converted to the corresponding positive-valued index before any masking is
    applied.

    More information on working with multivariate indices and handling
    connections where the number of seeds and targets are not equal can be
    found in the :doc:`../auto_examples/handling_ragged_arrays` example.
    """
I'm happy to add more comments directly in epochs.py and time.py.
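The pad-then-mask behaviour described in the docstring quoted above can be sketched in a few lines of NumPy. This is not the implementation from the PR: `pad_and_mask` is a hypothetical stand-in that only mirrors the documented steps (convert negative indices, pad with ``-1``, mask the padding), and it skips all input validation.

```python
import numpy as np


def pad_and_mask(indices, n_chans, pad_value=-1):
    """Pad ragged (seeds, targets) indices to a full shape and mask the padding."""
    # widest seed or target set across all connections
    max_n_chans = max(len(con) for group in indices for con in group)
    masked_groups = []
    for group in indices:  # seeds first, then targets
        padded = np.full((len(group), max_n_chans), pad_value, dtype=int)
        for con_i, con in enumerate(group):
            con = np.asarray(con, dtype=int)
            # convert negative channel indices to positive ones before padding
            con = np.where(con < 0, con + n_chans, con)
            padded[con_i, :len(con)] = con
        masked_groups.append(np.ma.masked_equal(padded, pad_value))
    return tuple(masked_groups)


ragged = ([[0, 1], [0, 1]],      # seeds
          [[2, 3], [4, 5, 6]])   # targets
seeds, targets = pad_and_mask(ragged, n_chans=7)

# negative indices are resolved before masking, so -1 here means channel 3
neg_seeds, neg_targets = pad_and_mask(([[0, -1]], [[2]]), n_chans=4)
```

Calling `compressed()` on one connection, e.g. `targets[0].compressed()`, returns only the valid channel indices for that connection.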
Of course, as soon as the connectivity object is returned, the masking is no longer just an internal implementation. The goal of this example was to provide some more context on this:
mne-connectivity/examples/handling_ragged_arrays.py
Lines 1 to 8 in 883e816
"""
=========================================================
Working with ragged indices for multivariate connectivity
=========================================================

This example demonstrates how multivariate connectivity involving different
numbers of seeds and targets can be handled in MNE-Connectivity.
"""
However I realise adding some documentation within the connectivity classes themselves also makes sense, e.g. here:
mne-connectivity/mne_connectivity/utils/docs.py
Lines 36 to 43 in 883e816
docdict['indices'] = """
indices : tuple of arrays | str | None
    The indices of relevant connectivity data. If ``'all'`` (default),
    then data is connectivity between all nodes. If ``'symmetric'``,
    then data is symmetric connectivity between all nodes. If a tuple,
    then the first list represents the "in nodes", and the second list
    represents the "out nodes". See "Notes" for more information.
"""
Happy to add this as well.
- Expanded upon the inline comments per the suggestion.
- Was also able to clean up some code now that _check_multivariate_indices is private.
Minor comment, but otw LGTM upon a quick glance! The implementation is much cleaner :)
Co-authored-by: Eric Larson <[email protected]>
…ctivity into pr-mvc_padding
In it goes! Thanks a million @tsbinns !!
Amazing work @tsbinns !
[ENH] Add support for ragged connections with multivariate methods with padding
Following discussions in #125.
I have added support for handling 'ragged' connections (i.e. where the number of seeds and targets differs across and within connections) with the multivariate methods MIC, MIM, and state-space GC introduced in #138. This affects: the indices attribute; and the patterns attribute (associated with MIC, with values returned for each channel in the seeds and targets). The connectivity results themselves are unaffected due to the built-in dimensionality reduction, and so return a single connectivity spectrum regardless of the number of seeds and targets. I have written an example which explains the problem and the workaround in more detail as @adam2392 suggested:
mne-connectivity/examples/handling_ragged_arrays.py
Lines 2 to 7 in 73792c3
In short: ragged arrays are difficult to work with and cannot be saved using the h5netcdf engine. To get around this problem we can pad the arrays for the missing entries with some known value to make them 'full'.
Padding indices
For padding indices, I chose the value -1. Because the entries need to be integers to be interpretable for indexing, we cannot represent values such as np.nan or np.inf in the arrays, so I picked -1 as an invalid channel index. E.g. say we want to look at two connections, the first between 2 seeds and 3 targets, and the second between 4 seeds and 1 target; indices could have the form:
ragged_seeds   = [[0, 1   ], [5, 6, 7, 8]]
ragged_targets = [[2, 3, 4], [9         ]]
ragged_indices = (ragged_seeds, ragged_targets)
The ragged seeds/targets thus have shape [cons, **variable**]. This can be passed to the connectivity functions, and will internally be padded to:
padded_seeds   = [[0, 1, -1, -1], [5,  6,  7,  8]]
padded_targets = [[2, 3,  4, -1], [9, -1, -1, -1]]
padded_indices = (padded_seeds, padded_targets)
The padded seeds/targets now have shape [cons, max_channels], which can be stored in the connectivity objects. The consistent shape makes them easier to handle and also compatible with saving.
Padding patterns
For padding patterns, I chose the invalid value np.nan. patterns have shape [2, cons, channels, freqs (, times)] (where the 2 represents entries for the seeds and targets, respectively). Accordingly, if the number of channels differs across the seed and target within a connection, or across connections, patterns will be ragged with shape [2, cons, **variable**, freqs (, times)].
By padding the missing entries with np.nan along the channel dimension, we can make the array shape consistent with [2, cons, max_channels, freqs (, times)] (e.g. [2, 2, 4, freqs (, times)] using the indices in the above example). Extracting only those valid entries from the patterns for a given connection is very simple, as I show in the example linked above.
Other changes
In line with these changes, I updated the relevant unit tests in test_spectral.py, and also edited the multivariate connectivity examples to take advantage of the new support for ragged connections. I also included two new helper functions check_multivariate_indices and seed_target_multivariate_indices (analogous to the existing functions but just for the new multivariate indices format; also with unit tests).
Overall I think this approach is quite clean: end-users can give ragged indices to a helper function which will do the padding, or pass them directly to the connectivity methods which will handle everything. Extracting the relevant information from padded indices and patterns is also trivial. Finally, the internal implementation is simple, so maintenance should not be problematic.
Please let me know your thoughts!
Thomas
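The NaN padding of patterns described in the PR description can be sketched as follows. This is an illustration only: the values are random, the number of frequencies is arbitrary, and the channel counts are taken from the two-connection example in the description (2 seeds and 3 targets, then 4 seeds and 1 target). None of the variable names come from MNE-Connectivity.

```python
import numpy as np

n_freqs = 5                       # arbitrary, for illustration
n_seeds = [2, 4]                  # seed channels per connection
n_targets = [3, 1]                # target channels per connection
n_cons = len(n_seeds)
max_channels = max(n_seeds + n_targets)

# full array of shape [2, cons, max_channels, freqs], NaN wherever there is no channel
patterns = np.full((2, n_cons, max_channels, n_freqs), np.nan)
rng = np.random.default_rng(0)
for con_i in range(n_cons):
    patterns[0, con_i, :n_seeds[con_i]] = rng.standard_normal(
        (n_seeds[con_i], n_freqs))
    patterns[1, con_i, :n_targets[con_i]] = rng.standard_normal(
        (n_targets[con_i], n_freqs))

# extract only the valid entries for the first connection
seed_patterns_con0 = patterns[0, 0, :n_seeds[0]]
target_patterns_con0 = patterns[1, 0, :n_targets[0]]
```

Because the number of valid channels per connection is known from the indices, slicing along the channel axis is enough to drop the padding again.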