Skip to content

Commit

Permalink
Refact: cleaned up spectral estimation code.
Browse files Browse the repository at this point in the history
  • Loading branch information
cgohil8 committed Jun 14, 2024
1 parent 275dd6f commit 5862fe6
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions osl_dynamics/analysis/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,10 +1406,10 @@ def _welch(
# Rescale PSDs to account for the number of time points
# each state was active
fo = np.sum(stc, axis=0) / stc.shape[0]
for i, (psd_, fo_) in enumerate(zip(psd, fo)):
psd_ /= fo_
if np.isnan(psd_).any():
psd[i] = np.nan_to_num(psd_) # zero out nan values
for i in range(len(fo)):
psd[i] /= fo[i]
if np.isnan(psd[i]).any():
psd[i] = np.nan_to_num(psd[i]) # zero out nan values
_logger.warn(
"PSD contains NaN values. This may indicate a potentially "
"poor HMM fit. You should consider running the model again "
Expand Down Expand Up @@ -1714,7 +1714,6 @@ def _multitaper(

if calc_coh:
# Create a channels by channels matrix for cross PSDs
n_channels = data.shape[-1]
n_freq = p.shape[-1]
cpsd = np.empty(
[n_channels, n_channels, n_freq],
Expand All @@ -1734,10 +1733,10 @@ def _multitaper(
# Rescale PSDs to account for the number of time points
# each state was active
fo = np.sum(stc, axis=0) / stc.shape[0]
for i, (psd_, fo_) in enumerate(zip(psd, fo)):
psd_ /= fo_
if np.isnan(psd_).any():
psd[i] = np.nan_to_num(psd_) # zero out nan values
for i in range(len(fo)):
psd[i] /= fo[i]
if np.isnan(psd[i]).any():
psd[i] = np.nan_to_num(psd[i]) # zero out nan values
_logger.warn(
"PSD contains NaN values. This may indicate a potentially "
"poor HMM fit. You should consider running the model again "
Expand Down

0 comments on commit 5862fe6

Please sign in to comment.