diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 7265ef9b..2881cc00 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -819,3 +819,4 @@ class _GCTREst(_GCEstBase): _multivariate_methods = ["mic", "mim", "gc", "gc_tr"] _gc_methods = ["gc", "gc_tr"] +_patterns_methods = ["mic"] # methods with spatial patterns diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index b5d5a648..58b49cb1 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -21,6 +21,7 @@ _check_rank_input, _gc_methods, _multivariate_methods, + _patterns_methods, ) from .smooth import _create_kernel, _smooth_spectra @@ -541,10 +542,14 @@ def spectral_connectivity_time( conn_patterns = dict() for m in method: conn[m] = np.zeros((n_epochs, n_cons, n_freqs)) - # patterns shape of [epochs x seeds/targets x cons x channels x freqs] - conn_patterns[m] = np.full( - (n_epochs, 2, n_cons, max_n_channels, n_freqs), np.nan - ) + # prevent allocating memory for a huge array if not required + if m in _patterns_methods: + # patterns shape of [epochs x seeds/targets x cons x channels x freqs] + conn_patterns[m] = np.full( + (n_epochs, 2, n_cons, max_n_channels, n_freqs), np.nan + ) + else: + conn_patterns[m] = None logger.info("Connectivity computation...") # parameters to pass to the connectivity function @@ -577,12 +582,10 @@ def spectral_connectivity_time( scores, patterns = _spectral_connectivity(data[epoch_idx], **call_params) for m in method: conn[m][epoch_idx] = np.stack(scores[m], axis=0) - if multivariate_con and patterns[m] is not None: + if patterns[m] is not None: conn_patterns[m][epoch_idx] = np.stack(patterns[m], axis=0) for m in method: - if np.isnan(conn_patterns[m]).all(): - conn_patterns[m] = None - else: + if conn_patterns[m] is not None: # transpose to [seeds/targets x epochs x cons x channels x freqs] conn_patterns[m] = conn_patterns[m].transpose((1, 0, 2, 3, 4))