Skip to content

Commit

Permalink
add tests, error checks, fixed bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Feb 10, 2023
1 parent 388283c commit 5f9b500
Show file tree
Hide file tree
Showing 7 changed files with 411 additions and 224 deletions.
42 changes: 38 additions & 4 deletions mne_connectivity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,12 @@ def _check_topographies_consistency(self, topographies):

def _check_n_components_consistency(self, n_components):
"""Perform n_components input checks."""
# useful for converting back to a tuple when re-loading after saving
if isinstance(n_components, np.ndarray):
n_components = tuple(copy(n_components.tolist()))
elif isinstance(n_components, list):
n_components = tuple(copy(n_components))

if not isinstance(n_components, tuple):
raise TypeError('n_components should be a tuple')

Expand Down Expand Up @@ -1165,6 +1171,7 @@ def save(self, fname):
"""
old_attrs = deepcopy(self.attrs)
self._pad_ragged_attrs()
self._replace_none_n_components()
super(BaseMultivariateConnectivity, self).save(fname)
self.xarray.attrs = old_attrs # resets to non-padded attrs

Expand Down Expand Up @@ -1195,7 +1202,7 @@ def _pad_indices(self, max_n_channels):
padded_indices = [[], []]
for group_i, group in enumerate(self.indices):
for con_i, con in enumerate(group):
if np.nonzero(con == self._pad_val):
if np.count_nonzero(con == self._pad_val):
# this would break the unpadding process when re-loading the
# connectivity object
raise ValueError(
Expand All @@ -1217,7 +1224,7 @@ def _pad_topographies(self, max_n_channels):
longer ragged (i.e. the length of the first dimension of topographies
for each connection equals 'max_n_channels')."""
topos_dims = [2, len(self.indices[0]), max_n_channels, len(self.freqs)]
if 'times' in self.attrs.keys():
if 'times' in self.coords:
topos_dims.append(len(self.times))
padded_topos = np.full(
topos_dims, self._pad_val, dtype=self.topographies[0][0].dtype
Expand All @@ -1229,14 +1236,29 @@ def _pad_topographies(self, max_n_channels):

self.attrs['topographies'] = padded_topos

def _unpad_ragged_attrs(self):
def _replace_none_n_components(self):
"""Replace None values in the n_components attribute with 'n/a', since
None is not supported by netCDF."""
n_components = [[], []]
for group_i, group in enumerate(self.attrs['n_components']):
for con in group:
if con is None:
n_components[group_i].append('n/a')
else:
n_components[group_i].append(con)
self.attrs['n_components'] = tuple(n_components)

def _restore_attrs(self):
"""Unpads ragged attributes of the connectivity object (i.e. indices and
topographies) padded with np.inf so that they could be saved using
topographies) padded with np.inf and restored nested None values in
attributes replaced with 'n/a' so that they could be saved using
HDF5."""
n_padded_channels = self._get_n_padded_channels()
self._unpad_indices(n_padded_channels)
if self.topographies is not None:
self._unpad_topographies(n_padded_channels)

self._restore_non_n_components()

def _get_n_padded_channels(self):
"""Finds the number of channels that have been added when padding the
Expand Down Expand Up @@ -1282,6 +1304,18 @@ def _unpad_topographies(self, n_padded_channels):

self.attrs['topographies'] = unpadded_topos

def _restore_non_n_components(self):
"""Restores None values in the n_components attribute with from
'n/a'."""
n_components = [[], []]
for group_i, group in enumerate(self.attrs['n_components']):
for con in group:
if con == 'n/a':
n_components[group_i].append(None)
else:
n_components[group_i].append(con)
self.attrs['n_components'] = tuple(n_components)


class MultivariateSpectralConnectivity(
SpectralConnectivity, BaseMultivariateConnectivity
Expand Down
21 changes: 11 additions & 10 deletions mne_connectivity/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
MultivariateSpectroTemporalConnectivity)


def _xarray_to_conn(array, cls_func, unpad_ragged_attrs):
def _xarray_to_conn(array, cls_func, restore_attrs):
"""Create connectivity class from xarray.
Parameters
Expand All @@ -20,9 +20,10 @@ def _xarray_to_conn(array, cls_func, unpad_ragged_attrs):
Xarray containing the connectivity data.
cls_func : Connectivity class
The function of the connectivity class to use.
unpad_ragged_attrs : bool
Whether or not to unpad once ragged attributes of the class that were
padded to enable saving with HDF5.
restore_attrs : bool
Whether or not to restore the nature of attributes of the class that
were modified to enable saving with HDF5 (only relevant for multivariate
connectivity classes).
Returns
-------
Expand Down Expand Up @@ -63,9 +64,9 @@ def _xarray_to_conn(array, cls_func, unpad_ragged_attrs):
data=data, names=names, metadata=metadata, **array.attrs
)

# make padded xarray attrs ragged again (for multivariate connectivity only)
if unpad_ragged_attrs:
conn._unpad_ragged_attrs()
# restore attrs modified for saving (for multivariate connectivity only)
if restore_attrs:
conn._restore_attrs()

return conn

Expand Down Expand Up @@ -117,10 +118,10 @@ def read_connectivity(fname):
# checks whether ragged attrs of the class padded for saving need to be
# restored (so far only the case for multivariate connectivity)
if issubclass(cls_func, BaseMultivariateConnectivity):
unpad_ragged_attrs = True
restore_attrs = True
else:
unpad_ragged_attrs = False
restore_attrs = False

# get the data as a new connectivity container
conn = _xarray_to_conn(conn_da, cls_func, unpad_ragged_attrs)
conn = _xarray_to_conn(conn_da, cls_func, restore_attrs)
return conn
68 changes: 44 additions & 24 deletions mne_connectivity/spectral/epochs_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,9 @@ def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1):

self.mic_scores = copy.deepcopy(self.con_scores)
self.mim_scores = copy.deepcopy(self.con_scores)
self.topographies = np.empty((2, n_cons), dtype=object)

def compute_con(
self, seeds, targets, n_seed_components, n_target_components, n_epochs,
form_name
self, seeds, targets, n_components, n_epochs, form_name
):
"""Computes MIC and/or MIM between sets of signals."""
self._sort_form_name(form_name)
Expand All @@ -208,7 +206,9 @@ def compute_con(
n_times = csd.shape[0]

con_i = 0
for seed_idcs, target_idcs in zip(seeds, targets):
for seed_idcs, target_idcs, n_seed_comps, n_target_comps in zip(
seeds, targets, n_components[0], n_components[1]
):
self._log_connection_number(con_i, f'coherence ({form_name})')

n_seeds = len(seed_idcs)
Expand All @@ -221,8 +221,7 @@ def compute_con(
C_bar, U_bar_aa, U_bar_bb = self._cross_spectra_svd(
csd=C,
n_seeds=n_seeds,
n_seed_components=n_seed_components[con_i],
n_target_components=n_target_components[con_i],
n_components=(n_seed_comps, n_target_comps)
)

# Eqs. 3 & 4
Expand Down Expand Up @@ -256,9 +255,12 @@ def _sort_form_name(self, form_name):
self.compute_mic = True
else: # only MIM left
self.compute_mim = True

if self.compute_mic:
self.topographies = np.empty((2, self.n_cons), dtype=object)

def _cross_spectra_svd(
self, csd, n_seeds, n_seed_components, n_target_components
self, csd, n_seeds, n_components
):
"""Performs dimensionality reduction on a cross-spectral density using
singular value decomposition (SVD)."""
Expand All @@ -270,24 +272,32 @@ def _cross_spectra_svd(
C_ba = csd[:, :, n_seeds:, :n_seeds]

# Eq. 32
if n_seed_components is not None:
if n_components[0] is not None:
U_aa = np.linalg.svd(np.real(C_aa), full_matrices=False)[0]
U_bar_aa = U_aa[:, :, :, :n_seed_components]
U_bar_aa = U_aa[:, :, :, :n_components[0]]
else:
U_bar_aa = np.broadcast_to(np.identity(n_seeds), (n_times, self.n_freqs)+(n_seeds, n_seeds))
if n_target_components is not None:
U_bar_aa = np.broadcast_to(
np.identity(n_seeds),
(n_times, self.n_freqs) + (n_seeds, n_seeds)
)
if n_components[1] is not None:
U_bb = np.linalg.svd(np.real(C_bb), full_matrices=False)[0]
U_bar_bb = U_bb[:, :, :, :n_target_components]
U_bar_bb = U_bb[:, :, :, :n_components[1]]
else:
U_bar_bb = np.broadcast_to(np.identity(n_targets), (n_times, self.n_freqs)+(n_targets, n_targets))
U_bar_bb = np.broadcast_to(
np.identity(n_targets),
(n_times, self.n_freqs) + (n_targets, n_targets)
)

# Eq. 33
C_bar_aa = U_bar_aa.transpose(0, 1, 3, 2) @ (C_aa @ U_bar_aa)
C_bar_ab = U_bar_aa.transpose(0, 1, 3, 2) @ (C_ab @ U_bar_bb)
C_bar_bb = U_bar_bb.transpose(0, 1, 3, 2) @ (C_bb @ U_bar_bb)
C_bar_ba = U_bar_bb.transpose(0, 1, 3, 2) @ (C_ba @ U_bar_aa)
C_bar = np.append(
np.append(C_bar_aa, C_bar_ab, axis=3), np.append(C_bar_ba, C_bar_bb, axis=3), axis=2
np.append(C_bar_aa, C_bar_ab, axis=3),
np.append(C_bar_ba, C_bar_bb, axis=3),
axis=2
)

return C_bar, U_bar_aa, U_bar_bb
Expand All @@ -313,12 +323,13 @@ def _compute_e(self, csd, n_seeds):
)
T = T.transpose(1, 0, 2, 3)

if not np.all(np.isreal(T)):
if not np.isreal(T).all() or not np.isfinite(T).all():
raise ValueError(
'the transformation matrix of the data must be real-valued, '
'but it is not; check that you are using full-rank data or '
'specifying an appropriate number of components for the seeds '
'and targets that is less than or equal to their ranks'
'the transformation matrix of the data must be real-valued and '
'contain no NaN or infinity values; check that you are using '
'full rank data or specify an appropriate number of components '
'for the seeds and targets that is less than or equal to their '
'ranks'
)
T = np.real(T)

Expand Down Expand Up @@ -549,21 +560,30 @@ def autocov_to_full_var(self, autocov):
Ref.: Whittle P., 1963. Biometrika, DOI: 10.1093/biomet/50.1-2.129.
"""
if np.any(np.linalg.det(autocov) == 0):
raise ValueError(
'the autocovariance matrix is singular; make sure you are '
'using only full rank data, or specify an appropriate number '
'of components for the seeds and targets that is less than or '
'equal to their ranks'
)

A_f, V = self.whittle_lwr_recursion(autocov)

if not np.isfinite(A_f).all():
raise ValueError(
"Some or all VAR model coefficients are infinite or NaNs. "
"Please check the data you are computing Granger causality on."
'some or all VAR model coefficients are infinite or NaNs; '
'please check the data you are computing connectivity on'
)

try:
np.linalg.cholesky(V)
except np.linalg.linalg.LinAlgError as np_error:
raise ValueError(
"The residuals' covariance matrix is not positive-definite. "
"Make sure you are computing Granger causality only on data "
"that is full rank."
'the residuals covariance matrix is not positive-definite; '
'make sure you are using only full rank data, or specify an '
'appropriate number of components for the seeds and targets '
'that is less than or equal to their ranks'
) from np_error

return A_f, V
Expand Down
Loading

0 comments on commit 5f9b500

Please sign in to comment.