diff --git a/osl_dynamics/analysis/spectral.py b/osl_dynamics/analysis/spectral.py index bc1bd96c..40094a06 100644 --- a/osl_dynamics/analysis/spectral.py +++ b/osl_dynamics/analysis/spectral.py @@ -1406,10 +1406,10 @@ def _welch( # Rescale PSDs to account for the number of time points # each state was active fo = np.sum(stc, axis=0) / stc.shape[0] - for i, (psd_, fo_) in enumerate(zip(psd, fo)): - psd_ /= fo_ - if np.isnan(psd_).any(): - psd[i] = np.nan_to_num(psd_) # zero out nan values + for i in range(len(fo)): + psd[i] /= fo[i] + if np.isnan(psd[i]).any(): + psd[i] = np.nan_to_num(psd[i]) # zero out nan values _logger.warn( "PSD contains NaN values. This may indicate a potentially " "poor HMM fit. You should consider running the model again " @@ -1714,7 +1714,6 @@ def _multitaper( if calc_coh: # Create a channels by channels matrix for cross PSDs - n_channels = data.shape[-1] n_freq = p.shape[-1] cpsd = np.empty( [n_channels, n_channels, n_freq], @@ -1734,10 +1733,10 @@ def _multitaper( # Rescale PSDs to account for the number of time points # each state was active fo = np.sum(stc, axis=0) / stc.shape[0] - for i, (psd_, fo_) in enumerate(zip(psd, fo)): - psd_ /= fo_ - if np.isnan(psd_).any(): - psd[i] = np.nan_to_num(psd_) # zero out nan values + for i in range(len(fo)): + psd[i] /= fo[i] + if np.isnan(psd[i]).any(): + psd[i] = np.nan_to_num(psd[i]) # zero out nan values _logger.warn( "PSD contains NaN values. This may indicate a potentially " "poor HMM fit. You should consider running the model again "