Skip to content

Commit

Permalink
Create patterns array only when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Mar 6, 2024
1 parent 196a7fc commit 93b2f83
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
1 change: 1 addition & 0 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,3 +819,4 @@ class _GCTREst(_GCEstBase):

_multivariate_methods = ["mic", "mim", "gc", "gc_tr"]
_gc_methods = ["gc", "gc_tr"]
_patterns_methods = ["mic"] # methods with spatial patterns
18 changes: 10 additions & 8 deletions mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_check_rank_input,
_gc_methods,
_multivariate_methods,
_patterns_methods,
)
from .smooth import _create_kernel, _smooth_spectra

Expand Down Expand Up @@ -541,10 +542,13 @@ def spectral_connectivity_time(
conn_patterns = dict()
for m in method:
conn[m] = np.zeros((n_epochs, n_cons, n_freqs))
# patterns shape of [epochs x seeds/targets x cons x channels x freqs]
conn_patterns[m] = np.full(
(n_epochs, 2, n_cons, max_n_channels, n_freqs), np.nan
)
if m in _patterns_methods:
# patterns shape of [epochs x seeds/targets x cons x channels x freqs]
conn_patterns[m] = np.full(
(n_epochs, 2, n_cons, max_n_channels, n_freqs), np.nan
)
else:
conn_patterns[m] = None
logger.info("Connectivity computation...")

# parameters to pass to the connectivity function
Expand Down Expand Up @@ -577,12 +581,10 @@ def spectral_connectivity_time(
scores, patterns = _spectral_connectivity(data[epoch_idx], **call_params)
for m in method:
conn[m][epoch_idx] = np.stack(scores[m], axis=0)
if multivariate_con and patterns[m] is not None:
if patterns[m] is not None:
conn_patterns[m][epoch_idx] = np.stack(patterns[m], axis=0)
for m in method:
if np.isnan(conn_patterns[m]).all():
conn_patterns[m] = None
else:
if conn_patterns[m] is not None:
# transpose to [seeds/targets x epochs x cons x channels x freqs]
conn_patterns[m] = conn_patterns[m].transpose((1, 0, 2, 3, 4))

Expand Down

0 comments on commit 93b2f83

Please sign in to comment.