[ENH] Multiple improvements to spectral_connectivity_time: ciPLV, and efficient computation of multiple metrics #115
All connectivity measures are now computed with only a single computation of pairwise cross spectrum.
In some scenarios, users might want to specify the frequencies for time-frequency decomposition also when using multitapering. These changes allow users to specify the 'freqs' parameter to override the automatically determined frequencies.
This adds the option to use the edges of the signal at each epoch as padding. The purpose of this is to avoid edge effects generated by the time-frequency transformation methods.
Sym is not a parameter of dpss_windows. (But is one of the underlying scipy.signal.dpss)
This change will skip the rendering of the connectivity computation progress bar if the logging level is not DEBUG. This is in line with MNE-Python, where progress bars are not shown at INFO or higher logging levels. Rendering the progress bar regardless of logging levels has the potential to cause unnecessary clutter in users' log files.
Add a better description of the method + style nitpicks.
It looks like the CIs are not happy about the docs and loading the Matplotlib Qt backend fails with Python 3.7. I can't check what's wrong with the docs it seems. Building locally, it complains about
mne_connectivity/spectral/time.py
Outdated
    if padding:
        pad_idx = int(np.floor(padding * sfreq / decim))
        out = out[..., pad_idx:-pad_idx]
        weights = weights[..., pad_idx:-pad_idx]
Do you mind also adding a test for padding?
Sure
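As a rough idea of what a padding test could exercise, the sketch below applies the same slicing as the diff above to a toy array. `trim_padding` is a hypothetical stand-alone function, and the array shape and parameter values are made up for illustration. The guard for `pad_idx == 0` is an addition of this sketch: without it, `out[..., 0:-0]` would return an empty array.

```python
import numpy as np


def trim_padding(out, padding, sfreq, decim=1):
    """Drop `padding` seconds from both ends of the last (time) axis."""
    if not padding:
        return out
    # Number of samples to drop per side, accounting for decimation.
    pad_idx = int(np.floor(padding * sfreq / decim))
    if pad_idx == 0:
        # out[..., 0:-0] would be empty, so return the data unchanged.
        return out
    return out[..., pad_idx:-pad_idx]


sfreq, decim = 100.0, 2
tfr = np.zeros((3, 5, 200))  # (n_signals, n_freqs, n_times), illustrative
trimmed = trim_padding(tfr, padding=0.5, sfreq=sfreq, decim=decim)
print(trimmed.shape)
```

A test along these lines could compare the output length against `n_times - 2 * pad_idx` and also cover very small padding values.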
mne_connectivity/spectral/time.py
Outdated
if m == 'coh':
    s_xx = psd[w_x]
    s_yy = psd[w_y]
    coh = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \
        np.sqrt(s_xx.mean(axis=-1, keepdims=True) *
                s_yy.mean(axis=-1, keepdims=True))
    out.append(coh)

if m == 'plv':
    dphi_mean = dphi.mean(axis=-1, keepdims=True)
    plv = np.abs(dphi_mean)
    out.append(plv)

if m == 'ciplv':
    rplv = np.abs(np.mean(np.real(dphi), axis=-1, keepdims=True))
    iplv = np.abs(np.mean(np.imag(dphi), axis=-1, keepdims=True))
    ciplv = iplv / (np.sqrt(1 - rplv ** 2))
    out.append(ciplv)

if m == 'pli':
    pli = np.abs(np.mean(np.sign(np.imag(s_xy)),
                         axis=-1, keepdims=True))
    out.append(pli)

if m == 'wpli':
    con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True))
    con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True)
    wpli = con_num / con_den
    out.append(wpli)

if m == 'cs':
    out.append(s_xy)
WDYT about keeping the basic pairwise "computation" functions as separate functions, so that they can be self-documented in their docstrings? E.g. keeping all the old functions?
To add, say, a new spectral-connectivity function, one would then have to: i) add a new private function, e.g. def _newpli(x, y), and then ii) update _parallel_con to add this option.
Might be easier to read. But also not 100% necessary.
It's an option for sure. However, it should be done such that we avoid computing the cross-spectral density multiple times. This is slow due to the potentially large size of the arrays according to some profiling I did.
Will take a look to see if this is possible then.
I wonder if it's possible to keep even the same "spectral-conn" base functions across 'epochs' and 'time' computation. I.e. there are the classes inside spectral_conn_epochs.py which also compute certain spectral connectivity. They seem to all operate on a pairwise nature, so perhaps we can even consolidate into one place, so that adding one connectivity function adds it for both 'epochs' and 'time'.
I'm afraid it wouldn't be practical to use the classes from spectral_connectivity_epochs. Those estimate connectivity as a mean over epochs, with the assumption that the connectivity measure is stationary within an Epoch. An estimator's accumulate method is therefore called only once per Epoch. Cross-spectral density is computed in a single time window spanning the whole Epoch (with the exception of cwt_morlet mode).
In spectral_connectivity_time the pairwise computation involves summing over the time axis, and multitaper actually means convolution with a DPSS-tapered sinusoid (similar to the Gaussian-tapered sinusoid for morlet). Summing over the time axis using something like the accumulate function in epochs would necessitate a for-loop over time. There's also the possibility of fetching connectivity over time separately for each epoch, which of course doesn't exist in epochs.
In my opinion, accommodating both within the same classes would probably lead to major refactoring of spectral_connectivity_epochs and potentially make the code less readable.
Sorry if the explanation is too basic, I'm mostly writing it down to keep it straight for myself.
WDYT about doing it this way?
def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs,
                  verbose, total, faverage, weights):
    """Compute spectral connectivity in parallel.

    Input signal w is of shape (n_chans, n_tapers, n_freqs, n_times)."""
    if 'coh' in method:
        # psd
        if weights is not None:
            psd = weights * w
            psd = psd * np.conj(psd)
            psd = psd.real.sum(axis=1)
            psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0)
        else:
            psd = w.real ** 2 + w.imag ** 2
            psd = np.squeeze(psd, axis=1)

        # smooth
        psd = _smooth_spectra(psd, kernel)

    def pairwise_con(w_x, w_y):
        # csd
        if weights is not None:
            s_xy = np.sum(weights * w[w_x] * np.conj(weights * w[w_y]), axis=0)
            s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=0)
        else:
            s_xy = w[w_x] * np.conj(w[w_y])
            s_xy = np.squeeze(s_xy, axis=0)
        s_xy = _smooth_spectra(s_xy, kernel)
        out = []
        conn_func = {'plv': _plv, 'ciplv': _ciplv, 'pli': _pli, 'wpli': _wpli,
                     'coh': _coh, 'cs': _cs}
        for m in method:
            if m == 'coh':
                s_xx = psd[w_x]
                s_yy = psd[w_y]
                out.append(conn_func[m](s_xx, s_yy, s_xy))
            else:
                out.append(conn_func[m](s_xy))

        for i, _ in enumerate(out):
            # mean inside frequency sliding window (if needed)
            if isinstance(foi_idx, np.ndarray) and faverage:
                out[i] = _foi_average(out[i], foi_idx)
            # squeeze time dimension
            out[i] = out[i].squeeze(axis=-1)

        return out

    # only show progress if verbosity level is DEBUG
    if verbose != 'DEBUG' and verbose != 'debug' and verbose != 10:
        total = None

    # define the function to compute in parallel
    parallel, p_fun, n_jobs = parallel_func(
        pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total)

    return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx))


def _plv(s_xy):
    s_xy = s_xy / np.abs(s_xy)
    plv = np.abs(s_xy.mean(axis=-1, keepdims=True))
    return plv


def _ciplv(s_xy):
    s_xy = s_xy / np.abs(s_xy)
    rplv = np.abs(np.mean(np.real(s_xy), axis=-1, keepdims=True))
    iplv = np.abs(np.mean(np.imag(s_xy), axis=-1, keepdims=True))
    ciplv = iplv / (np.sqrt(1 - rplv ** 2))
    return ciplv


def _pli(s_xy):
    pli = np.abs(np.mean(np.sign(np.imag(s_xy)),
                         axis=-1, keepdims=True))
    return pli


def _wpli(s_xy):
    con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True))
    con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True)
    wpli = con_num / con_den
    return wpli


def _coh(s_xx, s_yy, s_xy):
    con_num = np.abs(s_xy.mean(axis=-1, keepdims=True))
    con_den = np.sqrt(s_xx.mean(axis=-1, keepdims=True) *
                      s_yy.mean(axis=-1, keepdims=True))
    coh = con_num / con_den
    return coh


def _cs(s_xy):
    return s_xy.mean(axis=-1, keepdims=True)
Yeah that looks better and more similar to the epochs file.
Idr but can we put the pairwise function outside? Nested functions imo are harder to read. But if that breaks the parallelization via joblib, then nvm.
I'll check whether the pairwise function works outside _parallel_con. Fully agree that nested functions are ugly.
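To illustrate the idea discussed in this thread, the sketch below computes the cross-spectrum once and derives several metrics from it through a dispatch dictionary of module-level functions. It is a toy version under stated assumptions: the inputs are synthetic unit phasors rather than tapered time-frequency output, there is no smoothing kernel or taper weighting, and `connectivity_from_csd` is a hypothetical name that does not exist in the package.

```python
import numpy as np


def _plv(s_xy):
    # Normalise to unit phasors, then take the length of the mean vector.
    phase = s_xy / np.abs(s_xy)
    return np.abs(phase.mean(axis=-1))


def _ciplv(s_xy):
    phase = s_xy / np.abs(s_xy)
    rplv = np.abs(phase.real.mean(axis=-1))
    iplv = np.abs(phase.imag.mean(axis=-1))
    return iplv / np.sqrt(1 - rplv ** 2)


def _wpli(s_xy):
    return np.abs(s_xy.imag.mean(axis=-1)) / np.abs(s_xy.imag).mean(axis=-1)


CONN_FUNC = {'plv': _plv, 'ciplv': _ciplv, 'wpli': _wpli}


def connectivity_from_csd(x, y, methods):
    """Compute every requested metric from one cross-spectrum (hypothetical)."""
    s_xy = x * np.conj(y)  # computed once, reused by all metrics
    return {m: CONN_FUNC[m](s_xy) for m in methods}


rng = np.random.default_rng(0)
n_times = 2000
phase_x = rng.uniform(0, 2 * np.pi, n_times)
# y lags x by roughly a quarter cycle, with some phase jitter.
phase_y = phase_x - np.pi / 2 + 0.3 * rng.standard_normal(n_times)
x, y = np.exp(1j * phase_x), np.exp(1j * phase_y)

con = connectivity_from_csd(x, y, ['plv', 'ciplv', 'wpli'])
for name, value in con.items():
    print(name, float(value))
```

With this layout, adding a metric means writing one small documented function and registering it in the dictionary, while the expensive product `x * np.conj(y)` is still evaluated only once per pair.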
mne_connectivity/spectral/time.py
Outdated
def _compute_freqs(n_times, sfreq, freqs, mode):
    from scipy.fft import rfftfreq
    # get frequencies of interest for the different modes
    if freqs is not None:
        if any(freqs > (sfreq / 2.)):
I think this is defined already in spectral.epochs.py. Can we use that one?
The function in epochs.py would always return rfftfreq(n_times, 1./sfreq) if multitaper is used. This results in a very dense frequency grid, making computation too slow to be practical for wide frequency bands. Therefore, I would like to allow freqs to be optionally set by the user.
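The density argument can be made concrete with a small comparison. The sampling frequency, epoch length, band and 2 Hz step below are illustrative values, not taken from the PR, and `np.fft.rfftfreq` is used here in place of the `scipy.fft.rfftfreq` import shown in the diff.

```python
import numpy as np

sfreq = 250.0    # sampling frequency in Hz (illustrative value)
n_times = 2000   # samples per epoch (illustrative value)

# Frequency grid implied by the FFT length: one bin every sfreq / n_times Hz.
dense = np.fft.rfftfreq(n_times, 1.0 / sfreq)
in_band = dense[(dense >= 4.0) & (dense <= 40.0)]

# A user-chosen grid covering the same band at 2 Hz resolution.
freqs = np.arange(4.0, 41.0, 2.0)

print(len(dense), len(in_band), len(freqs))
```

Since the time-resolved decomposition is evaluated once per frequency, the cost grows with the number of entries in the grid, which is why a coarse user-defined `freqs` can be much cheaper than the FFT grid.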
Can you add freqs into the existing function and then use it?
Having one private function makes the code more robust, easier to maintain, and less likely to introduce runtime bugs in the future.
Technically yes, but then we would have to prevent the use of cwt_freqs together with multitaper in spectral_connectivity_epochs. In that case freqs must be equal to rfftfreq(n_times, 1./sfreq) as that's the shape of the output of _csd_from_mt and _psd_from_mt.
I think that is fine. So, we would check if freqs is not None and mode == 'multitaper' and raise an error.
The function is essentially doing the same thing regardless of which one it's being called from.
This should be the case only in spectral_connectivity_epochs (where multitaper frequencies must be defined by rfftfreq for the method to work). In spectral_connectivity_time there isn't really much reason to even use rfftfreq to obtain frequencies, as it uses a different (time-resolved) multitaper implementation. If we require freqs anyway with morlet, it might make sense to go back to requiring them for multitaper as well, given that, as far as I understand, in the time-resolved tfr implementation the difference between morlet and multitaper is only in the type of wavelet used (+ multiple tapers with multitaper).
Maybe a solution could be to add the connectivity paradigm (epochs vs time) to the parameters of _compute_freqs. Although the reasoning in my previous comment would still favor requiring freqs in multitaper mode as well.
Sure, I think requiring freqs could be fine. I'm not 100% sure I'm following wrt all the different possible combinations of parameters here, so I can take a look after you push up the proposed change?
Should be fixed now on main.
Individual functions for each connectivity methods.
This change makes joblib happy to do multithreading.
Feel free to ping me whenever you need me to take a look again!
Happy new year @ruuskas! Just wanted to check in on this. It would be great to get this in and make a new release w/ all your cool improvements!
Happy new year @adam2392! I have been focusing on other things so this has been lagging. I'll try to get the necessary changes done this week (maybe today if there are no blockers).
The user is required to specify the wavelet central frequencies in both multitaper and cwt_morlet tfr mode. The reasoning is that the underlying tfr implementations are very similar. This is in contrast to spectral_connectivity_epochs, where multitaper assumes that the spectrum is stationary and therefore no wavelets are used.
Hi @adam2392! This should now be ready for review again. Sphinx is not happy, as there seems to be an issue with indentation in the docstring, which I couldn't figure out without changing the config.
The sphinx error:
is due to a change in sphinx 6.0 and is currently being addressed upstream in the website theme. A temporary workaround is to pin sphinx to <6.0.
The Sphinx error I get locally is different. I can copy it verbatim tomorrow, but it is essentially "unexpected indentation" in the docstring. The indentation looks fine where it points to, though.
LGTM! Some minor changes and clean up and this should be good to go soon.

Parameters
----------
data : array_like, shape (n_epochs, n_signals, n_times) | Epochs
    The data from which to compute connectivity.
freqs : array_like
It can be int or float, right? If so, what is the expected behavior? It just computes the connectivity at that frequency?
Signed-off-by: Adam Li <[email protected]>
Co-authored-by: Adam Li <[email protected]>
…nto spectral_time
@ruuskas I fixed the indentation error and other sphinx issue. You have to git pull to sync changes.
Hi @adam2392! I addressed all the things you pointed out and pushed changes. WDYT?
Co-authored-by: Adam Li <[email protected]>
def _pairwise_con(w, psd, x, y, method, kernel, foi_idx,
                  faverage, weights):
Even tho this is a private function, do you mind adding a basic docstring mainly for internal documentation purposes?
One minor nitpick, and then this LGTM for merging!
Thanks @ruuskas! Great addition and contribution.
Thanks @adam2392!
… efficient computation of multiple metrics (mne-tools#115)
* Add ciPLV: Add the corrected imaginary Phase-Locking-Value into the list of available connectivity metrics.
* Speed up computation: All connectivity measures are now computed with only a single computation of pairwise cross spectrum.
* Add the option to specify freqs in all modes: In some scenarios, users might want to specify the frequencies for time-frequency decomposition also when using multitapering. These changes allow users to specify the 'freqs' parameter to override the automatically determined frequencies.
* BUG: Average over CSD instead of connectivity
* Add option to use part of signal as padding: This adds the option to use the edges of the signal at each epoch as padding. The purpose of this is to avoid edge effects generated by the time-frequency transformation methods.
* Fix test bug, use 'freqs' instead of 'cwt_freqs'
* Fix bug with dpss windows: Sym is not a parameter of dpss_windows. (But is one of the underlying scipy.signal.dpss)
* Only show progress bar if verbosity level is DEBUG: This change will skip the rendering of the connectivity computation progress bar if the logging level is not DEBUG. This is in line with MNE-Python, where progress bars are not shown at INFO or higher logging levels. Rendering the progress bar regardless of logging levels has the potential to cause unnecessary clutter in users' log files.
* Require freqs in all tfr modes: The user is required to specify the wavelet central frequencies in both multitaper and cwt_morlet tfr mode. The reasoning is that the underlying tfr implementations are very similar. This is in contrast to spectral_connectivity_epochs, where multitaper assumes that the spectrum is stationary and therefore no wavelets are used.
* Require mne>=1.3
Signed-off-by: Adam Li <[email protected]>
Co-authored-by: Adam Li <[email protected]>
PR Description
This PR addresses improvements mentioned in #104. I'm reposting this with the Git log now cleaned up.
Merge checklist
Maintainer, please confirm the following before merging: