From 93b2f83aaf068fca0e3b7eae2c2d77d14f8ad87c Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 6 Mar 2024 10:58:15 +0100 Subject: [PATCH 1/2] Create patterns array only when needed --- .../spectral/epochs_multivariate.py | 1 + mne_connectivity/spectral/time.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 7265ef9b..2881cc00 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -819,3 +819,4 @@ class _GCTREst(_GCEstBase): _multivariate_methods = ["mic", "mim", "gc", "gc_tr"] _gc_methods = ["gc", "gc_tr"] +_patterns_methods = ["mic"] # methods with spatial patterns diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index b5d5a648..dd098358 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -21,6 +21,7 @@ _check_rank_input, _gc_methods, _multivariate_methods, + _patterns_methods, ) from .smooth import _create_kernel, _smooth_spectra @@ -541,10 +542,13 @@ def spectral_connectivity_time( conn_patterns = dict() for m in method: conn[m] = np.zeros((n_epochs, n_cons, n_freqs)) - # patterns shape of [epochs x seeds/targets x cons x channels x freqs] - conn_patterns[m] = np.full( - (n_epochs, 2, n_cons, max_n_channels, n_freqs), np.nan - ) + if m in _patterns_methods: + # patterns shape of [epochs x seeds/targets x cons x channels x freqs] + conn_patterns[m] = np.full( + (n_epochs, 2, n_cons, max_n_channels, n_freqs), np.nan + ) + else: + conn_patterns[m] = None logger.info("Connectivity computation...") # parameters to pass to the connectivity function @@ -577,12 +581,10 @@ def spectral_connectivity_time( scores, patterns = _spectral_connectivity(data[epoch_idx], **call_params) for m in method: conn[m][epoch_idx] = np.stack(scores[m], axis=0) - if multivariate_con and patterns[m] is not None: + if patterns[m] is not None: conn_patterns[m][epoch_idx] = np.stack(patterns[m], axis=0) for m in method: - if np.isnan(conn_patterns[m]).all(): - conn_patterns[m] = None - else: + if conn_patterns[m] is not None: # transpose to [seeds/targets x epochs x cons x channels x freqs] conn_patterns[m] = conn_patterns[m].transpose((1, 0, 2, 3, 4)) From a79ad28cce21b67771c72cc5db21c04d1ae4f9cb Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 6 Mar 2024 19:27:26 +0100 Subject: [PATCH 2/2] Add explanatory comment from code review Co-authored-by: Adam Li --- mne_connectivity/spectral/time.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index dd098358..58b49cb1 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -542,6 +542,7 @@ def spectral_connectivity_time( conn_patterns = dict() for m in method: conn[m] = np.zeros((n_epochs, n_cons, n_freqs)) + # prevent allocating memory for a huge array if not required if m in _patterns_methods: # patterns shape of [epochs x seeds/targets x cons x channels x freqs] conn_patterns[m] = np.full(